← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_orchestration.c
Go to the documentation of this file.
1 /**
2  * @file ckernel_orchestration.c
3  *
4  * ===========================================================================
5  * LEGACY CODE - NOT USED IN v6.6
6  * ===========================================================================
7  *
8  * This file contains v6.5 orchestration code that is NO LONGER USED.
9  * It is kept for reference and potential future use but is NOT compiled
10  * into the v6.6 engine.
11  *
12  * v6.6 Architecture:
13  * - IR Lower 3 handles all orchestration via dataflow graph
14  * - Kernel dispatch via ckernel_codegen.c (for dynamically loaded kernels)
15  * - Memory planning via memory_planner_v6_6.py
16  *
17  * Contents of this file (NOT used):
18  * - ck_attention_flash_decode_wrapper: Flash attention wrapper (use
19  * mega_fused_attention_prefill/avx instead)
20  * - ck_quantized_gemm: Dispatcher for Q4_K, Q5_0, Q5_1, Q6_K, Q8_0
21  * (use kernel_maps/KERNEL_REGISTRY.json + codegen instead)
22  *
23  * To remove completely:
24  * 1. Delete this file
25  * 2. Remove from Makefile SRCS list
26  * 3. Remove ckernel_orchestration.h
27  *
28  * Last used: v6.5
29  * Deprecated: v6.6 (2026-02)
30  * ===========================================================================
31  */
32 
33 #include "ckernel_orchestration.h"
34 
35 #include "ckernel_engine.h"
36 #include "ckernel_dtype.h"
37 #include "ckernel_quant.h"
38 
39 #include <stddef.h>
40 #include <stdio.h>
41 #include <stdlib.h>
42 #include <string.h>
43 #include <math.h>
44 
45 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
46 #include <immintrin.h>
47 #endif
48 
49 /* ============================================================================
50  * TRUE FLASH ATTENTION (O(1) for decode)
51  *
52  * New implementation based on Tri Dao's Flash Attention algorithm.
53  * Provides O(1) complexity for decode instead of O(n).
54  *
55  * Reference: attention_flash_decode() in src/kernels/attention_flash_true.c
56  * ============================================================================ */
57 
58 /**
59  * @brief Wrapper to call TRUE flash attention from orchestration layer
60  *
61  * @param q_token Query token [H, D_h]
62  * @param k_cache Cached keys [T_k, H, D_h]
63  * @param v_cache Cached values [T_k, H, D_h]
64  * @param out_token Output [H, D_h]
65  * @param num_heads Number of heads
66  * @param num_kv_heads Number of KV heads (for GQA)
67  * @param kv_tokens Number of tokens in KV cache
68  * @param cache_capacity Cache capacity
69  * @param head_dim Head dimension
70  * @param aligned_head_dim Aligned head dimension
71  */
73  const float *q_token,
74  const float *k_cache,
75  const float *v_cache,
76  float *out_token,
77  int num_heads,
78  int num_kv_heads,
79  int kv_tokens,
80  int cache_capacity,
81  int head_dim,
82  int aligned_head_dim)
83 {
84  if (!q_token || !k_cache || !v_cache || !out_token) {
85  return;
86  }
87  if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
88  return;
89  }
90  if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
91  return;
92  }
93 
94  static int use_strict = -1;
95  if (use_strict < 0) {
96  const char *env = getenv("CK_FLASH_ATTN_STRICT");
97  use_strict = (env && env[0] && env[0] != '0') ? 1 : 0;
98  }
99 
100  if (use_strict) {
102  k_cache,
103  v_cache,
104  out_token,
105  num_heads,
106  num_kv_heads,
107  kv_tokens,
108  cache_capacity,
109  head_dim,
110  aligned_head_dim);
111  return;
112  }
113 
114  // Scale factor: 1/sqrt(head_dim)
115  const float scale = 1.0f / sqrtf((float)head_dim);
116  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
117 
118 #pragma omp parallel for schedule(static) if(num_heads > 1)
119  for (int h = 0; h < num_heads; ++h) {
120  const int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
121  const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
122  const float *k_head = k_cache + (size_t)kv_head * head_stride;
123  const float *v_head = v_cache + (size_t)kv_head * head_stride;
124  float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;
125 
126  // Use aligned_head_dim as D_h so per-token stride matches the cache layout.
127  attention_flash_decode(out_head,
128  q_head,
129  k_head,
130  v_head,
131  1,
132  kv_tokens,
133  1,
134  aligned_head_dim,
135  scale);
136  }
137 }
138 
void ck_residual_add_token_major(const float *a,
                                 const float *b,
                                 float *out,
                                 int tokens,
                                 int aligned_embed_dim)
{
    /**
     * Element-wise residual add over token-major [tokens, aligned_embed_dim]
     * buffers: out[i] = a[i] + b[i].
     *
     * Fix: guard against NULL pointers and non-positive dimensions for
     * consistency with ck_residual_add_backward(), which already validates
     * its inputs before touching memory.
     */
    if (!a || !b || !out || tokens <= 0 || aligned_embed_dim <= 0) {
        return;
    }
    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
    for (size_t i = 0; i < total; ++i) {
        out[i] = a[i] + b[i];
    }
}
150 
void ck_residual_add_backward(const float *d_out,
                              float *d_a,
                              float *d_b,
                              int tokens,
                              int aligned_embed_dim)
{
    /* Backward pass of a residual add: the gradient flows unchanged into
     * both branches, so each input gradient is a plain copy of d_out. */
    if (d_out == NULL || d_a == NULL || d_b == NULL) {
        return;
    }
    const size_t count = (size_t)tokens * (size_t)aligned_embed_dim;
    for (size_t idx = 0; idx < count; ++idx) {
        const float g = d_out[idx];
        d_a[idx] = g;
        d_b[idx] = g;
    }
}
167 
void ck_qkv_project_head_major(const float *input,
                               const float *wq, const float *bq,
                               const float *wk, const float *bk,
                               const float *wv, const float *bv,
                               float *q, float *k, float *v,
                               int tokens,
                               int kv_stride_tokens,
                               int aligned_embed_dim,
                               int num_heads,
                               int num_kv_heads,
                               int aligned_head_dim)
{
    /* Per-head Q/K/V projections producing head-major outputs:
     *   q: [num_heads,    tokens,           aligned_head_dim]
     *   k: [num_kv_heads, kv_stride_tokens, aligned_head_dim]
     *   v: [num_kv_heads, kv_stride_tokens, aligned_head_dim]
     * Weights are one [aligned_head_dim, aligned_embed_dim] matrix per head;
     * the optional biases are one aligned_head_dim vector per head. */
    if (input == NULL || wq == NULL || wk == NULL || wv == NULL ||
        q == NULL || k == NULL || v == NULL) {
        return;
    }
    /* K/V buffers must hold at least `tokens` rows per head. */
    if (kv_stride_tokens < tokens) {
        return;
    }

    const size_t w_per_head = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_per_head = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t kv_per_head = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;

    for (int head = 0; head < num_heads; ++head) {
        const size_t b_off = (size_t)head * (size_t)aligned_head_dim;
        gemm_blocked_serial(input,
                            wq + (size_t)head * w_per_head,
                            (bq != NULL) ? bq + b_off : NULL,
                            q + (size_t)head * q_per_head,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }

    for (int head = 0; head < num_kv_heads; ++head) {
        const size_t b_off = (size_t)head * (size_t)aligned_head_dim;
        gemm_blocked_serial(input,
                            wk + (size_t)head * w_per_head,
                            (bk != NULL) ? bk + b_off : NULL,
                            k + (size_t)head * kv_per_head,
                            tokens, aligned_head_dim, aligned_embed_dim);
        gemm_blocked_serial(input,
                            wv + (size_t)head * w_per_head,
                            (bv != NULL) ? bv + b_off : NULL,
                            v + (size_t)head * kv_per_head,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }
}
216 
static int ck_layer_debug_enabled(void)
{
    /* Returns 1 when CK_LAYER_DEBUG starts with '1', 'y', or 'Y', else 0.
     * The getenv() lookup runs once; the result is memoized (-2 = unset). */
    static int cached = -2;
    if (cached == -2) {
        const char *flag = getenv("CK_LAYER_DEBUG");
        cached = (flag != NULL &&
                  (flag[0] == '1' || flag[0] == 'y' || flag[0] == 'Y')) ? 1 : 0;
    }
    return cached;
}
231 
static void ck_debug_check_buffer(const char *stage, const float *buf, int size)
{
    /* Debug-only scan of an fp32 buffer: reports NaN/Inf counts, or the
     * finite min/max range, to stderr. No-op unless CK_LAYER_DEBUG is on. */
    if (!ck_layer_debug_enabled() || buf == NULL) {
        return;
    }
    int n_nan = 0;
    int n_inf = 0;
    float lo = 1e38f;
    float hi = -1e38f;
    for (int i = 0; i < size; ++i) {
        const float x = buf[i];
        if (isnan(x)) {
            ++n_nan;
        } else if (isinf(x)) {
            ++n_inf;
        } else {
            lo = (x < lo) ? x : lo;
            hi = (x > hi) ? x : hi;
        }
    }
    if (n_nan || n_inf) {
        fprintf(stderr, "[LAYER_DEBUG] %-30s size=%5d nan=%d inf=%d\n",
                stage, size, n_nan, n_inf);
    } else {
        fprintf(stderr, "[LAYER_DEBUG] %-30s size=%5d range=[%.3e, %.3e]\n",
                stage, size, lo, hi);
    }
}
258 
259 static void ck_debug_check_q8k(const char *stage, const void *q8_buf, int num_blocks)
260 {
261  if (!ck_layer_debug_enabled() || !q8_buf) {
262  return;
263  }
264  const block_q8_K *blocks = (const block_q8_K *)q8_buf;
265  int nan_scale = 0, inf_scale = 0;
266  float min_d = 1e38f, max_d = -1e38f;
267  for (int i = 0; i < num_blocks; ++i) {
268  float d = blocks[i].d;
269  if (isnan(d)) {
270  nan_scale++;
271  } else if (isinf(d)) {
272  inf_scale++;
273  } else {
274  if (d < min_d) min_d = d;
275  if (d > max_d) max_d = d;
276  }
277  }
278  if (nan_scale > 0 || inf_scale > 0) {
279  fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d nan_scale=%d inf_scale=%d\n",
280  stage, num_blocks, nan_scale, inf_scale);
281  } else {
282  fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d scale_range=[%.3e, %.3e]\n",
283  stage, num_blocks, min_d, max_d);
284  }
285 }
286 
287 static void ck_debug_check_q4k_weights(const char *stage, const void *q4_buf, int num_blocks)
288 {
289  if (!ck_layer_debug_enabled() || !q4_buf) {
290  return;
291  }
292  const block_q4_K *blocks = (const block_q4_K *)q4_buf;
293  int nan_d = 0, nan_dmin = 0;
294  float min_d = 1e38f, max_d = -1e38f;
295  for (int i = 0; i < num_blocks; ++i) {
296  float d = CK_FP16_TO_FP32(blocks[i].d);
297  float dm = CK_FP16_TO_FP32(blocks[i].dmin);
298  if (isnan(d)) nan_d++;
299  if (isnan(dm)) nan_dmin++;
300  if (!isnan(d) && !isinf(d)) {
301  if (d < min_d) min_d = d;
302  if (d > max_d) max_d = d;
303  }
304  }
305  if (nan_d > 0 || nan_dmin > 0) {
306  fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d nan_d=%d nan_dmin=%d\n",
307  stage, num_blocks, nan_d, nan_dmin);
308  } else {
309  fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d d_range=[%.3e, %.3e]\n",
310  stage, num_blocks, min_d, max_d);
311  }
312 }
313 
315 {
316  static int cached = -2;
317  if (cached != -2) {
318  return cached;
319  }
320 
321  const char *env = getenv("CK_Q8K_ACTIVATIONS");
322  if (!env || !env[0]) {
323  cached = ck_strict_parity_enabled() ? 0 : 1;
324  return cached;
325  }
326  if (env[0] == '0' || env[0] == 'n' || env[0] == 'N' ||
327  env[0] == 'f' || env[0] == 'F') {
328  cached = 0;
329  } else {
330  cached = 1;
331  }
332  return cached;
333 }
334 
335 void ck_gemm_nt_quant(const float *A,
336  const void *B,
337  const float *bias,
338  float *C,
339  int M, int N, int K,
340  CKDataType dtype)
341 {
342  switch (dtype) {
343  case CK_DT_FP32:
344  gemm_blocked_serial(A, (const float *)B, bias, C, M, N, K);
345  break;
346  case CK_DT_Q4_K:
347  gemm_nt_q4_k(A, B, bias, C, M, N, K);
348  break;
349  case CK_DT_Q6_K:
350  gemm_nt_q6_k(A, B, bias, C, M, N, K);
351  break;
352  case CK_DT_Q4_0:
353  gemm_nt_q4_0(A, B, bias, C, M, N, K);
354  break;
355  case CK_DT_Q4_1:
356  gemm_nt_q4_1(A, B, bias, C, M, N, K);
357  break;
358  case CK_DT_Q5_0:
359  gemm_nt_q5_0(A, B, bias, C, M, N, K);
360  break;
361  case CK_DT_Q5_1:
362  gemm_nt_q5_1(A, B, bias, C, M, N, K);
363  break;
364  case CK_DT_Q8_0:
365  gemm_nt_q8_0(A, B, bias, C, M, N, K);
366  break;
367  default:
368  break;
369  }
370 }
371 
372 /* ============================================================================
373  * Q4_K (Q4_K_M) forward-only paths
374  * ============================================================================
375  *
376  * These helpers keep the same activation layouts as the fp32 code paths, but
377  * accept weight matrices stored as GGML-compatible Q4_K blocks. This is meant
378  * for weight-only quantized inference: activations remain fp32 by default, but
379  * decode can switch to Q8_K activations via CK_Q8K_ACTIVATIONS=1.
380  *
381  * Important constraints:
382  * - Q4_K kernels require K (the input dimension) to be a multiple of 256.
383  * - For attention output projection we assume the concatenated head vector
384  * has length aligned_embed_dim (i.e., num_heads * aligned_head_dim matches).
385  */
386 
387 static void ck_qkv_project_head_major_q4_k(const float *input,
388  const void *wq, const float *bq,
389  const void *wk, const float *bk,
390  const void *wv, const float *bv,
391  float *q, float *k, float *v,
392  int tokens,
393  int kv_stride_tokens,
394  int aligned_embed_dim,
395  int num_heads,
396  int num_kv_heads,
397  int aligned_head_dim)
398 {
399  if (!input || !wq || !wk || !wv || !q || !k || !v) {
400  return;
401  }
402  if (kv_stride_tokens < tokens) {
403  return;
404  }
405 
406  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
407  const size_t head_w_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
408  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
409  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
410 
411  const uint8_t *wq_bytes = (const uint8_t *)wq;
412  const uint8_t *wk_bytes = (const uint8_t *)wk;
413  const uint8_t *wv_bytes = (const uint8_t *)wv;
414 
415  for (int h = 0; h < num_heads; ++h) {
416  const void *wq_h = wq_bytes + (size_t)h * head_w_bytes;
417  const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
418  float *q_h = q + (size_t)h * q_head_stride;
419 
420  gemm_nt_q4_k(input, wq_h, bq_h, q_h,
421  tokens, aligned_head_dim, aligned_embed_dim);
422  }
423 
424  for (int h = 0; h < num_kv_heads; ++h) {
425  const void *wk_h = wk_bytes + (size_t)h * head_w_bytes;
426  const void *wv_h = wv_bytes + (size_t)h * head_w_bytes;
427 
428  const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
429  const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
430 
431  float *k_h = k + (size_t)h * kv_head_stride;
432  float *v_h = v + (size_t)h * kv_head_stride;
433 
434  gemm_nt_q4_k(input, wk_h, bk_h, k_h,
435  tokens, aligned_head_dim, aligned_embed_dim);
436  gemm_nt_q4_k(input, wv_h, bv_h, v_h,
437  tokens, aligned_head_dim, aligned_embed_dim);
438  }
439 }
440 
441 static void ck_qkv_project_head_major_quant(const float *input,
442  const void *wq, const float *bq, CKDataType wq_dtype,
443  const void *wk, const float *bk, CKDataType wk_dtype,
444  const void *wv, const float *bv, CKDataType wv_dtype,
445  float *q, float *k, float *v,
446  int tokens,
447  int kv_stride_tokens,
448  int aligned_embed_dim,
449  int num_heads,
450  int num_kv_heads,
451  int aligned_head_dim)
452 {
453  if (!input || !wq || !wk || !wv || !q || !k || !v) {
454  return;
455  }
456  if (kv_stride_tokens < tokens) {
457  return;
458  }
459 
460  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
461  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
462  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
463 
464  const size_t wq_head_bytes = ck_dtype_row_bytes(wq_dtype, head_w_elems);
465  const size_t wk_head_bytes = ck_dtype_row_bytes(wk_dtype, head_w_elems);
466  const size_t wv_head_bytes = ck_dtype_row_bytes(wv_dtype, head_w_elems);
467 
468  const uint8_t *wq_bytes = (const uint8_t *)wq;
469  const uint8_t *wk_bytes = (const uint8_t *)wk;
470  const uint8_t *wv_bytes = (const uint8_t *)wv;
471 
472  for (int h = 0; h < num_heads; ++h) {
473  const void *wq_h = (wq_dtype == CK_DT_FP32)
474  ? (const void *)((const float *)wq + (size_t)h * head_w_elems)
475  : (const void *)(wq_bytes + (size_t)h * wq_head_bytes);
476  const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
477  float *q_h = q + (size_t)h * q_head_stride;
478 
479  ck_gemm_nt_quant(input, wq_h, bq_h, q_h,
480  tokens, aligned_head_dim, aligned_embed_dim, wq_dtype);
481  }
482 
483  for (int h = 0; h < num_kv_heads; ++h) {
484  const void *wk_h = (wk_dtype == CK_DT_FP32)
485  ? (const void *)((const float *)wk + (size_t)h * head_w_elems)
486  : (const void *)(wk_bytes + (size_t)h * wk_head_bytes);
487  const void *wv_h = (wv_dtype == CK_DT_FP32)
488  ? (const void *)((const float *)wv + (size_t)h * head_w_elems)
489  : (const void *)(wv_bytes + (size_t)h * wv_head_bytes);
490 
491  const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
492  const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
493 
494  float *k_h = k + (size_t)h * kv_head_stride;
495  float *v_h = v + (size_t)h * kv_head_stride;
496 
497  ck_gemm_nt_quant(input, wk_h, bk_h, k_h,
498  tokens, aligned_head_dim, aligned_embed_dim, wk_dtype);
499  ck_gemm_nt_quant(input, wv_h, bv_h, v_h,
500  tokens, aligned_head_dim, aligned_embed_dim, wv_dtype);
501  }
502 }
503 
static void ck_attention_project_head_major_q4_k(const float *attn_out,
                                                 const void *wo,
                                                 const float *bo,
                                                 float *out,
                                                 float *scratch,
                                                 int tokens,
                                                 int aligned_embed_dim,
                                                 int num_heads,
                                                 int aligned_head_dim)
{
    /* Attention output projection with a Q4_K weight matrix. The head-major
     * attention output [H, T, ad] is first flattened into token-major
     * [T, H*ad] inside `scratch`, then projected in a single quantized
     * GEMM. Requires num_heads * aligned_head_dim == aligned_embed_dim. */
    if (attn_out == NULL || wo == NULL || out == NULL || scratch == NULL) {
        return;
    }

    const int flat_dim = num_heads * aligned_head_dim;
    if (flat_dim != aligned_embed_dim) {
        return;
    }

    const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t head_bytes = (size_t)aligned_head_dim * sizeof(float);

    for (int t = 0; t < tokens; ++t) {
        float *row = scratch + (size_t)t * (size_t)aligned_embed_dim;
        for (int h = 0; h < num_heads; ++h) {
            const float *src =
                attn_out + (size_t)h * head_stride + (size_t)t * (size_t)aligned_head_dim;
            memcpy(row + (size_t)h * (size_t)aligned_head_dim, src, head_bytes);
        }
    }

    gemm_nt_q4_k(scratch, wo, bo, out,
                 tokens, aligned_embed_dim, aligned_embed_dim);
}
539 
540 static void ck_attention_project_head_major_quant(const float *attn_out,
541  const void *wo,
542  const float *bo,
543  float *out,
544  float *scratch,
545  int tokens,
546  int aligned_embed_dim,
547  int num_heads,
548  int aligned_head_dim,
549  CKDataType wo_dtype)
550 {
551  if (!attn_out || !wo || !out || !scratch) {
552  return;
553  }
554 
555  if (wo_dtype == CK_DT_FP32) {
557  (const float *)wo,
558  bo,
559  out,
560  scratch,
561  tokens,
562  aligned_embed_dim,
563  num_heads,
564  aligned_head_dim);
565  return;
566  }
567 
568  const int K = num_heads * aligned_head_dim;
569  if (K != aligned_embed_dim) {
570  return;
571  }
572 
573  const size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
574 
575  for (int t = 0; t < tokens; ++t) {
576  float *dst = scratch + (size_t)t * (size_t)aligned_embed_dim;
577  for (int h = 0; h < num_heads; ++h) {
578  const float *src = attn_out + (size_t)h * head_in_stride + (size_t)t * (size_t)aligned_head_dim;
579  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
580  src,
581  (size_t)aligned_head_dim * sizeof(float));
582  }
583  }
584 
585  ck_gemm_nt_quant(scratch, wo, bo, out,
586  tokens, aligned_embed_dim, aligned_embed_dim, wo_dtype);
587 }
588 
static void ck_mlp_swiglu_forward_q4_k(const float *input,
                                       const void *w1,
                                       const float *b1,
                                       const void *w2,
                                       const float *b2,
                                       float *fc1_out,
                                       float *swiglu_out,
                                       float *output,
                                       int tokens,
                                       int aligned_embed_dim,
                                       int aligned_intermediate_dim)
{
    /* SwiGLU MLP with Q4_K weight matrices:
     *   fc1_out    = input * W1^T + b1       (fused gate+up, 2*inter wide)
     *   swiglu_out = swiglu(fc1_out)         (inter wide)
     *   output     = swiglu_out * W2^T + b2  (embed wide)
     */
    const int gate_up_dim = 2 * aligned_intermediate_dim;

    gemm_nt_q4_k(input, w1, b1, fc1_out,
                 tokens, gate_up_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
    gemm_nt_q4_k(swiglu_out, w2, b2, output,
                 tokens, aligned_embed_dim, aligned_intermediate_dim);
}
610 
611 static void ck_mlp_swiglu_forward_quant(const float *input,
612  const void *w1,
613  const float *b1,
614  CKDataType w1_dtype,
615  const void *w2,
616  const float *b2,
617  CKDataType w2_dtype,
618  float *fc1_out,
619  float *swiglu_out,
620  float *output,
621  int tokens,
622  int aligned_embed_dim,
623  int aligned_intermediate_dim)
624 {
625  int up_dim = 2 * aligned_intermediate_dim;
626  ck_gemm_nt_quant(input, w1, b1, fc1_out,
627  tokens, up_dim, aligned_embed_dim, w1_dtype);
628 
629  swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
630 
631  ck_gemm_nt_quant(swiglu_out, w2, b2, output,
632  tokens, aligned_embed_dim, aligned_intermediate_dim, w2_dtype);
633 }
634 
635 static void ck_mlp_swiglu_forward_q4_k_q8_k(const float *input,
636  const void *w1,
637  const float *b1,
638  const void *w2,
639  const float *b2,
640  float *fc1_out,
641  float *swiglu_out,
642  float *output,
643  int aligned_embed_dim,
644  int aligned_intermediate_dim)
645 {
646  if (!input || !w1 || !w2 || !fc1_out || !swiglu_out || !output) {
647  return;
648  }
649  if ((aligned_embed_dim % QK_K) != 0 || (aligned_intermediate_dim % QK_K) != 0) {
650  return;
651  }
652 
653  const int up_dim = 2 * aligned_intermediate_dim;
654  const int q8_blocks_embed = aligned_embed_dim / QK_K;
655  const int q8_blocks_inter = aligned_intermediate_dim / QK_K;
656  const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
657  block_q8_K q8_buf[q8_blocks_max];
658 
659  quantize_row_q8_k(input, q8_buf, aligned_embed_dim);
660  gemm_nt_q4_k_q8_k(q8_buf, w1, b1, fc1_out,
661  /*M=*/1, /*N=*/up_dim, /*K=*/aligned_embed_dim);
662 
663  swiglu_forward(fc1_out, swiglu_out, /*tokens=*/1, aligned_intermediate_dim);
664 
665  quantize_row_q8_k(swiglu_out, q8_buf, aligned_intermediate_dim);
666  gemm_nt_q4_k_q8_k(q8_buf, w2, b2, output,
667  /*M=*/1, /*N=*/aligned_embed_dim, /*K=*/aligned_intermediate_dim);
668 }
669 
static void ck_qkv_project_head_major_ref(const float *input,
                                          const float *wq, const float *bq,
                                          const float *wk, const float *bk,
                                          const float *wv, const float *bv,
                                          float *q, float *k, float *v,
                                          int tokens,
                                          int kv_stride_tokens,
                                          int aligned_embed_dim,
                                          int num_heads,
                                          int num_kv_heads,
                                          int aligned_head_dim)
{
    /* Reference variant of ck_qkv_project_head_major: identical layouts and
     * validation, but every projection goes through gemm_naive_parallel
     * instead of the blocked serial kernel. */
    if (input == NULL || wq == NULL || wk == NULL || wv == NULL ||
        q == NULL || k == NULL || v == NULL) {
        return;
    }
    if (kv_stride_tokens < tokens) {
        return;
    }

    const size_t w_per_head = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_per_head = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t kv_per_head = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;

    for (int head = 0; head < num_heads; ++head) {
        const size_t b_off = (size_t)head * (size_t)aligned_head_dim;
        gemm_naive_parallel(input,
                            wq + (size_t)head * w_per_head,
                            (bq != NULL) ? bq + b_off : NULL,
                            q + (size_t)head * q_per_head,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }

    for (int head = 0; head < num_kv_heads; ++head) {
        const size_t b_off = (size_t)head * (size_t)aligned_head_dim;
        gemm_naive_parallel(input,
                            wk + (size_t)head * w_per_head,
                            (bk != NULL) ? bk + b_off : NULL,
                            k + (size_t)head * kv_per_head,
                            tokens, aligned_head_dim, aligned_embed_dim);
        gemm_naive_parallel(input,
                            wv + (size_t)head * w_per_head,
                            (bv != NULL) ? bv + b_off : NULL,
                            v + (size_t)head * kv_per_head,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }
}
718 
static void ck_add_inplace(float *dst,
                           const float *src,
                           int tokens,
                           int aligned_embed_dim)
{
    /* Accumulate src into dst over a [tokens, aligned_embed_dim] buffer. */
    const size_t n = (size_t)tokens * (size_t)aligned_embed_dim;
    for (size_t i = 0; i < n; ++i) {
        dst[i] += src[i];
    }
}
729 
void ck_attention_project_head_major(const float *attn_out,
                                     const float *wo,
                                     const float *bo,
                                     float *out,
                                     float *scratch,
                                     int tokens,
                                     int aligned_embed_dim,
                                     int num_heads,
                                     int aligned_head_dim)
{
    /* Output projection over head-major attention results. Head 0 projects
     * straight into `out` (applying the bias exactly once); every further
     * head projects into `scratch` and is accumulated into `out`.
     * `scratch` is required only when num_heads > 1. */
    if (attn_out == NULL || wo == NULL || out == NULL) {
        return;
    }
    if (num_heads > 1 && scratch == NULL) {
        return;
    }

    const size_t in_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t w_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;

    for (int h = 0; h < num_heads; ++h) {
        const float *in_h = attn_out + (size_t)h * in_stride;
        const float *w_h = wo + (size_t)h * w_stride;

        if (h == 0) {
            gemm_blocked_serial(in_h, w_h, bo, out,
                                tokens, aligned_embed_dim, aligned_head_dim);
        } else {
            gemm_blocked_serial(in_h, w_h, NULL, scratch,
                                tokens, aligned_embed_dim, aligned_head_dim);
            ck_add_inplace(out, scratch, tokens, aligned_embed_dim);
        }
    }
}
764 
static void ck_attention_project_head_major_ref(const float *attn_out,
                                                const float *wo,
                                                const float *bo,
                                                float *out,
                                                float *scratch,
                                                int tokens,
                                                int aligned_embed_dim,
                                                int num_heads,
                                                int aligned_head_dim)
{
    /* Reference variant of ck_attention_project_head_major: same
     * project-then-accumulate scheme, but routed through
     * gemm_naive_parallel instead of the blocked serial kernel. */
    if (attn_out == NULL || wo == NULL || out == NULL) {
        return;
    }
    if (num_heads > 1 && scratch == NULL) {
        return;
    }

    const size_t in_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t w_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;

    for (int h = 0; h < num_heads; ++h) {
        const float *in_h = attn_out + (size_t)h * in_stride;
        const float *w_h = wo + (size_t)h * w_stride;

        if (h == 0) {
            gemm_naive_parallel(in_h, w_h, bo, out,
                                tokens, aligned_embed_dim, aligned_head_dim);
        } else {
            gemm_naive_parallel(in_h, w_h, NULL, scratch,
                                tokens, aligned_embed_dim, aligned_head_dim);
            ck_add_inplace(out, scratch, tokens, aligned_embed_dim);
        }
    }
}
799 
801  const float *attn_out,
802  const float *wo,
803  float *d_attn_out,
804  float *d_wo,
805  float *d_bo,
806  int tokens,
807  int aligned_embed_dim,
808  int num_heads,
809  int aligned_head_dim)
810 {
811  if (!d_out || !attn_out || !wo || !d_attn_out || !d_wo || !d_bo) {
812  return;
813  }
814 
815  // Bias gradient: sum over tokens once (bias is applied once in forward).
816  for (int d = 0; d < aligned_embed_dim; ++d) {
817  d_bo[d] = 0.0f;
818  }
819  for (int t = 0; t < tokens; ++t) {
820  const float *row = d_out + (size_t)t * (size_t)aligned_embed_dim;
821  for (int d = 0; d < aligned_embed_dim; ++d) {
822  d_bo[d] += row[d];
823  }
824  }
825 
826  size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
827  size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
828 
829  float *tmp_b = (float *)calloc((size_t)aligned_embed_dim, sizeof(float));
830  if (!tmp_b) {
831  return;
832  }
833 
834  for (int h = 0; h < num_heads; ++h) {
835  const float *head_in = attn_out + (size_t)h * head_in_stride;
836  const float *wo_h = wo + (size_t)h * head_weight_stride;
837  float *d_head_in = d_attn_out + (size_t)h * head_in_stride;
838  float *d_wo_h = d_wo + (size_t)h * head_weight_stride;
839 
840  memset(tmp_b, 0, (size_t)aligned_embed_dim * sizeof(float));
841  fc2_backward_kernel(d_out,
842  head_in,
843  wo_h,
844  d_head_in,
845  d_wo_h,
846  tmp_b,
847  tokens,
848  aligned_head_dim,
849  aligned_embed_dim,
850  1);
851  }
852 
853  free(tmp_b);
854 }
855 
857  const float *d_k,
858  const float *d_v,
859  const float *input,
860  const float *wq,
861  const float *bq,
862  const float *wk,
863  const float *bk,
864  const float *wv,
865  const float *bv,
866  float *d_input,
867  float *d_wq,
868  float *d_bq,
869  float *d_wk,
870  float *d_bk,
871  float *d_wv,
872  float *d_bv,
873  float *scratch,
874  int tokens,
875  int aligned_embed_dim,
876  int num_heads,
877  int num_kv_heads,
878  int aligned_head_dim,
879  int num_threads)
880 {
881  if (!d_q || !d_k || !d_v || !input || !wq || !wk || !wv ||
882  !d_input || !d_wq || !d_bq || !d_wk || !d_bk || !d_wv || !d_bv || !scratch) {
883  return;
884  }
885 
886  size_t total_in = (size_t)tokens * (size_t)aligned_embed_dim;
887  for (size_t i = 0; i < total_in; ++i) {
888  d_input[i] = 0.0f;
889  }
890 
891  size_t head_weight_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
892  size_t head_out_stride = (size_t)tokens * (size_t)aligned_head_dim;
893 
894  for (int h = 0; h < num_heads; ++h) {
895  const float *d_q_h = d_q + (size_t)h * head_out_stride;
896  const float *wq_h = wq + (size_t)h * head_weight_stride;
897  float *d_wq_h = d_wq + (size_t)h * head_weight_stride;
898  float *d_bq_h = d_bq + (size_t)h * (size_t)aligned_head_dim;
899 
900  fc2_backward_kernel(d_q_h,
901  input,
902  wq_h,
903  scratch,
904  d_wq_h,
905  d_bq_h,
906  tokens,
907  aligned_embed_dim,
908  aligned_head_dim,
909  num_threads);
910  ck_add_inplace(d_input, scratch, tokens, aligned_embed_dim);
911  }
912 
913  for (int h = 0; h < num_kv_heads; ++h) {
914  const float *d_k_h = d_k + (size_t)h * head_out_stride;
915  const float *d_v_h = d_v + (size_t)h * head_out_stride;
916 
917  const float *wk_h = wk + (size_t)h * head_weight_stride;
918  const float *wv_h = wv + (size_t)h * head_weight_stride;
919 
920  float *d_wk_h = d_wk + (size_t)h * head_weight_stride;
921  float *d_wv_h = d_wv + (size_t)h * head_weight_stride;
922 
923  float *d_bk_h = d_bk + (size_t)h * (size_t)aligned_head_dim;
924  float *d_bv_h = d_bv + (size_t)h * (size_t)aligned_head_dim;
925 
926  fc2_backward_kernel(d_k_h,
927  input,
928  wk_h,
929  scratch,
930  d_wk_h,
931  d_bk_h,
932  tokens,
933  aligned_embed_dim,
934  aligned_head_dim,
935  num_threads);
936  ck_add_inplace(d_input, scratch, tokens, aligned_embed_dim);
937 
938  fc2_backward_kernel(d_v_h,
939  input,
940  wv_h,
941  scratch,
942  d_wv_h,
943  d_bv_h,
944  tokens,
945  aligned_embed_dim,
946  aligned_head_dim,
947  num_threads);
948  ck_add_inplace(d_input, scratch, tokens, aligned_embed_dim);
949  }
950 }
951 
void ck_mlp_swiglu_forward(const float *input,
                           const float *w1,
                           const float *b1,
                           const float *w2,
                           const float *b2,
                           float *fc1_out,
                           float *swiglu_out,
                           float *output,
                           int tokens,
                           int aligned_embed_dim,
                           int aligned_intermediate_dim)
{
    /* fp32 SwiGLU MLP:
     *   fc1_out    = input * W1^T + b1       (fused gate+up, 2*inter wide)
     *   swiglu_out = swiglu(fc1_out)         (inter wide)
     *   output     = swiglu_out * W2^T + b2  (embed wide)
     */
    const int gate_up_dim = 2 * aligned_intermediate_dim;

    gemm_blocked_serial(input, w1, b1, fc1_out,
                        tokens, gate_up_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
    gemm_blocked_serial(swiglu_out, w2, b2, output,
                        tokens, aligned_embed_dim, aligned_intermediate_dim);
}
973 
static void ck_mlp_swiglu_forward_ref(const float *input,
                                      const float *w1,
                                      const float *b1,
                                      const float *w2,
                                      const float *b2,
                                      float *fc1_out,
                                      float *swiglu_out,
                                      float *output,
                                      int tokens,
                                      int aligned_embed_dim,
                                      int aligned_intermediate_dim)
{
    /* Reference fp32 SwiGLU MLP: identical dataflow to
     * ck_mlp_swiglu_forward, but both GEMMs use gemm_naive_parallel. */
    const int gate_up_dim = 2 * aligned_intermediate_dim;

    gemm_naive_parallel(input, w1, b1, fc1_out,
                        tokens, gate_up_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
    gemm_naive_parallel(swiglu_out, w2, b2, output,
                        tokens, aligned_embed_dim, aligned_intermediate_dim);
}
995 
/*
 * NOTE(review): this Doxygen listing dropped every hyperlinked line, so the
 * function signature (listing line 996) and the names of most callees below
 * are missing; only their argument lists survive.  Judging from the visible
 * arguments this is a full transformer-layer prefill forward over a layer
 * parameter struct `p`:
 *   RMSNorm(ln1) -> QKV projection -> optional RoPE -> attention (explicit
 *   `scores`-buffer variant or fused variant) -> output projection (Wo) ->
 *   residual add -> RMSNorm(ln2) -> SwiGLU MLP -> final residual into
 *   p->output.  TODO: recover the exact signature/callee names from the
 *   original ckernel_orchestration.c before reusing this code.
 */
997 {
998  if (!p) {
999  return;
1000  }
1001 
/* LN1 / RMSNorm over p->input (callee name lost at listing line 1002). */
1003  p->ln1_gamma,
1004  p->ln1_out,
1005  p->ln1_rstd,
1006  p->tokens,
1007  p->embed_dim,
1008  p->aligned_embed_dim,
1009  p->eps);
1010 
/* QKV projection for all tokens (callee lost at line 1011). */
1012  p->wq, p->bq,
1013  p->wk, p->bk,
1014  p->wv, p->bv,
1015  p->q, p->k, p->v,
1016  p->tokens,
1017  p->tokens,
1018  p->aligned_embed_dim,
1019  p->num_heads,
1020  p->num_kv_heads,
1021  p->aligned_head_dim);
1022 
1023  if (p->rope_cos && p->rope_sin) {
1024  rope_forward_qk(p->q,
1025  p->k,
1026  p->rope_cos,
1027  p->rope_sin,
1028  p->num_heads,
1029  p->num_kv_heads,
1030  p->tokens,
1031  p->head_dim,
1032  p->aligned_head_dim,
1033  p->rope_pos_offset);
1034  }
1035 
/* Attention: scores-buffer path (callee lost at 1037, trailing arg at 1047)
 * vs. fused path without a scores buffer (callee lost at 1049). */
1036  if (p->scores) {
1038  p->k,
1039  p->v,
1040  p->scores,
1041  p->attn_out,
1042  p->num_heads,
1043  p->num_kv_heads,
1044  p->tokens,
1045  p->head_dim,
1046  p->aligned_head_dim,
1048  } else {
1050  p->k,
1051  p->v,
1052  p->attn_out,
1053  p->num_heads,
1054  p->num_kv_heads,
1055  p->tokens,
1056  p->head_dim,
1057  p->aligned_head_dim);
1058  }
1059 
/* Output projection Wo into p->proj_tmp using p->proj_scratch (callee lost
 * at line 1060). */
1061  p->wo,
1062  p->bo,
1063  p->proj_tmp,
1064  p->proj_scratch,
1065  p->tokens,
1066  p->aligned_embed_dim,
1067  p->num_heads,
1068  p->aligned_head_dim);
1069 
/* Residual add (presumably input + proj_tmp -> residual1; callee lost). */
1071  p->proj_tmp,
1072  p->residual1,
1073  p->tokens,
1074  p->aligned_embed_dim);
1075 
/* LN2 / RMSNorm over residual1 (callee lost at line 1076). */
1077  p->ln2_gamma,
1078  p->ln2_out,
1079  p->ln2_rstd,
1080  p->tokens,
1081  p->embed_dim,
1082  p->aligned_embed_dim,
1083  p->eps);
1084 
/* SwiGLU MLP (callee lost at 1085; trailing intermediate-dim arg at 1095). */
1086  p->w1,
1087  p->b1,
1088  p->w2,
1089  p->b2,
1090  p->fc1_out,
1091  p->swiglu_out,
1092  p->mlp_out,
1093  p->tokens,
1094  p->aligned_embed_dim,
1096 
/* Final residual add into p->output (callee lost at line 1097). */
1098  p->mlp_out,
1099  p->output,
1100  p->tokens,
1101  p->aligned_embed_dim);
1102 }
1103 
/*
 * NOTE(review): signature lost at listing line 1104 (hyperlinked line dropped
 * by the Doxygen extraction).  Body mirrors the previous prefill forward
 * stage-for-stage with identical argument lists, so this is presumably the
 * same pipeline routed through different kernel variants (e.g. a reference or
 * differently-parallelized path) — confirm against the original source.
 * Pipeline: RMSNorm -> QKV -> optional RoPE -> attention -> Wo projection ->
 * residual -> RMSNorm -> SwiGLU MLP -> final residual.
 */
1105 {
1106  if (!p) {
1107  return;
1108  }
1109 
/* LN1 / RMSNorm (callee name lost at listing line 1110). */
1111  p->ln1_gamma,
1112  p->ln1_out,
1113  p->ln1_rstd,
1114  p->tokens,
1115  p->embed_dim,
1116  p->aligned_embed_dim,
1117  p->eps);
1118 
/* QKV projection (callee lost at line 1119). */
1120  p->wq, p->bq,
1121  p->wk, p->bk,
1122  p->wv, p->bv,
1123  p->q, p->k, p->v,
1124  p->tokens,
1125  p->tokens,
1126  p->aligned_embed_dim,
1127  p->num_heads,
1128  p->num_kv_heads,
1129  p->aligned_head_dim);
1130 
1131  if (p->rope_cos && p->rope_sin) {
1132  rope_forward_qk(p->q,
1133  p->k,
1134  p->rope_cos,
1135  p->rope_sin,
1136  p->num_heads,
1137  p->num_kv_heads,
1138  p->tokens,
1139  p->head_dim,
1140  p->aligned_head_dim,
1141  p->rope_pos_offset);
1142  }
1143 
/* Attention (callees lost at 1145 / 1157; trailing arg at 1155). */
1144  if (p->scores) {
1146  p->k,
1147  p->v,
1148  p->scores,
1149  p->attn_out,
1150  p->num_heads,
1151  p->num_kv_heads,
1152  p->tokens,
1153  p->head_dim,
1154  p->aligned_head_dim,
1156  } else {
1158  p->k,
1159  p->v,
1160  p->attn_out,
1161  p->num_heads,
1162  p->num_kv_heads,
1163  p->tokens,
1164  p->head_dim,
1165  p->aligned_head_dim);
1166  }
1167 
/* Wo projection with scratch buffer (callee lost at line 1168). */
1169  p->wo,
1170  p->bo,
1171  p->proj_tmp,
1172  p->proj_scratch,
1173  p->tokens,
1174  p->aligned_embed_dim,
1175  p->num_heads,
1176  p->aligned_head_dim);
1177 
/* Residual add (callee lost at line 1178). */
1179  p->proj_tmp,
1180  p->residual1,
1181  p->tokens,
1182  p->aligned_embed_dim);
1183 
/* LN2 / RMSNorm (callee lost at line 1184). */
1185  p->ln2_gamma,
1186  p->ln2_out,
1187  p->ln2_rstd,
1188  p->tokens,
1189  p->embed_dim,
1190  p->aligned_embed_dim,
1191  p->eps);
1192 
/* SwiGLU MLP (callee lost at 1193; trailing intermediate-dim arg at 1203). */
1194  p->w1,
1195  p->b1,
1196  p->w2,
1197  p->b2,
1198  p->fc1_out,
1199  p->swiglu_out,
1200  p->mlp_out,
1201  p->tokens,
1202  p->aligned_embed_dim,
1204 
/* Final residual add into p->output (callee lost at line 1205). */
1206  p->mlp_out,
1207  p->output,
1208  p->tokens,
1209  p->aligned_embed_dim);
1210 }
1211 
/*
 * Single-token SwiGLU MLP using the fused gate/up kernel.
 *
 * w1 packs W_gate followed by W_up (each [aligned_intermediate_dim,
 * aligned_embed_dim]); b1, when present, packs the gate bias followed by the
 * up bias.  The fused GEMM writes the activated intermediate into swiglu_row,
 * then a blocked-serial GEMM applies the down projection (w2/b2) into
 * output_row.  Biases are optional (NULL allowed); missing mandatory
 * pointers make the call a no-op.
 */
void ck_mlp_swiglu_forward_fused_token(const float *input_row,
                                       const float *w1,
                                       const float *b1,
                                       const float *w2,
                                       const float *b2,
                                       float *swiglu_row,
                                       float *output_row,
                                       int aligned_embed_dim,
                                       int aligned_intermediate_dim)
{
    if (input_row == NULL || w1 == NULL || w2 == NULL ||
        swiglu_row == NULL || output_row == NULL) {
        return;
    }

    /* Locate the gate/up halves inside the packed w1 / b1 buffers. */
    const size_t gate_span = (size_t)aligned_intermediate_dim * (size_t)aligned_embed_dim;
    const float *gate_w = w1;
    const float *up_w   = w1 + gate_span;
    const float *gate_b = b1;
    const float *up_b   = (b1 != NULL) ? b1 + aligned_intermediate_dim : NULL;

    /* Fused gate+up GEMM with SwiGLU activation, one row (M = 1). */
    gemm_swiglu_fused(input_row,
                      gate_w,
                      up_w,
                      gate_b,
                      up_b,
                      swiglu_row,
                      /*M=*/1,
                      /*N=*/aligned_intermediate_dim,
                      /*K=*/aligned_embed_dim);

    /* Down projection back to the embedding dimension. */
    gemm_blocked_serial(swiglu_row, w2, b2, output_row,
                        /*M=*/1,
                        /*N=*/aligned_embed_dim,
                        /*K=*/aligned_intermediate_dim);
}
1246 
/*
 * Single-token SwiGLU MLP with all three projections fused into one kernel.
 *
 * Unlike ck_mlp_swiglu_forward_fused_token, the activated intermediate never
 * leaves the kernel: fused_mlp_swiglu_decode_v2 computes gate, up, SwiGLU and
 * the down projection in one pass, avoiding the DRAM round-trip for the
 * intermediate row.
 *
 * Layouts (aligned dimensions — weights are stored with alignment padding):
 *   w1 = [W_gate ; W_up], each [aligned_intermediate_dim, aligned_embed_dim]
 *   b1 = [b_gate ; b_up] when non-NULL
 *   w2 = W_down, [aligned_embed_dim, aligned_intermediate_dim]
 * Missing mandatory pointers make the call a no-op; biases may be NULL.
 */
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row,
                                             const float *w1,
                                             const float *b1,
                                             const float *w2,
                                             const float *b2,
                                             float *output_row,
                                             int aligned_embed_dim,
                                             int aligned_intermediate_dim)
{
    if (input_row == NULL || w1 == NULL || w2 == NULL || output_row == NULL) {
        return;
    }

    /* Split the packed w1/b1 buffers into their gate and up halves. */
    const size_t gate_span = (size_t)aligned_intermediate_dim * (size_t)aligned_embed_dim;
    const float *gate_w = w1;
    const float *up_w   = w1 + gate_span;
    const float *gate_b = b1;
    const float *up_b   = (b1 != NULL) ? b1 + aligned_intermediate_dim : NULL;

    fused_mlp_swiglu_decode_v2(input_row,
                               gate_w,
                               up_w,
                               /*w_down=*/w2,
                               gate_b,
                               up_b,
                               /*b_down=*/b2,
                               output_row,
                               aligned_embed_dim,
                               aligned_intermediate_dim);
}
1288 
/*
 * NOTE(review): the function-name line (listing 1289) and several callee-name
 * lines were dropped by the Doxygen extraction.  From the surviving body this
 * is the single-token decode forward: normalize the new token, project Q/K/V,
 * apply RoPE at absolute position rope_pos_offset, append K/V to the cache at
 * token_index, run decode attention over the first token_index+1 cached
 * entries, project through Wo, then residual + LN2 + SwiGLU MLP + residual.
 *
 * NOTE(review): q_token/k_token/v_token/attn_token are runtime-sized VLAs
 * (H * aligned_head_dim floats each) — stack usage grows with head count;
 * fine for typical model sizes but worth confirming against large configs.
 */
1290  int token_index,
1291  int cache_capacity)
1292 {
1293  if (!p) {
1294  return;
1295  }
1296  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
1297  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
1298  !p->k || !p->v ||
1299  !p->proj_tmp || !p->residual1 || !p->fc1_out || !p->swiglu_out || !p->mlp_out || !p->output) {
1300  return;
1301  }
1302  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
1303  return;
1304  }
1305  if (p->num_heads <= 0 || p->num_kv_heads <= 0 || p->aligned_head_dim <= 0) {
1306  return;
1307  }
1308 
1309  const int D = p->embed_dim;
1310  const int aligned_D = p->aligned_embed_dim;
1311  const int H = p->num_heads;
1312  const int H_kv = p->num_kv_heads;
1313  const int hd = p->head_dim;
1314  const int ad = p->aligned_head_dim;
1315  const int aligned_intermediate = p->aligned_intermediate_dim;
1316 
1317  /* Decode buffers are single-token; token_index only applies to KV cache. */
1318  const size_t token_slot = 0;
1319  const float *input_row = p->input + token_slot * (size_t)aligned_D;
1320  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
1321  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
1322  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
1323  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
1324  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
1325  float *out_row = p->output + token_slot * (size_t)aligned_D;
1326 
1327  float ln1_rstd_tmp = 0.0f;
1328  float ln2_rstd_tmp = 0.0f;
1329  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
1330  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
1331 
1332  // Scratch for a single token in head-major layout: [head, aligned_head_dim].
1333  size_t q_elems = (size_t)H * (size_t)ad;
1334  size_t kv_elems = (size_t)H_kv * (size_t)ad;
1335  float q_token[q_elems];
1336  float k_token[kv_elems];
1337  float v_token[kv_elems];
1338  float attn_token[q_elems];
1339 
1340  // LN1 / RMSNorm.
1341  rmsnorm_forward(input_row,
1342  p->ln1_gamma,
1343  ln1_row,
1344  ln1_rstd,
1345  /*tokens=*/1,
1346  D,
1347  aligned_D,
1348  p->eps);
1349 
1350  // Project Q/K/V for the new token.
/* (Callee name lost at listing line 1351 — single-token QKV projection
 * taking ln1_row; confirm against the original source.) */
1352  p->wq, p->bq,
1353  p->wk, p->bk,
1354  p->wv, p->bv,
1355  q_token, k_token, v_token,
1356  aligned_D,
1357  H,
1358  H_kv,
1359  ad);
1360 
1361  // RoPE for the new token at absolute position `p->rope_pos_offset`.
1362  if (p->rope_cos && p->rope_sin) {
1363  rope_forward_qk(q_token,
1364  k_token,
1365  p->rope_cos,
1366  p->rope_sin,
1367  H,
1368  H_kv,
1369  /*num_tokens=*/1,
1370  hd,
1371  ad,
1372  p->rope_pos_offset);
1373  }
1374 
1375  // Update KV cache (stores k/v for this token and clears padded lanes).
1376  kv_cache_write_head_major(k_token,
1377  v_token,
1378  p->k,
1379  p->v,
1380  H_kv,
1381  token_index,
1382  cache_capacity,
1383  hd,
1384  ad);
1385 
1386  // Decode attention for this token using the KV cache.
/* (Callee name lost at listing line 1387 — decode attention over q_token and
 * the cache.) */
1388  p->k,
1389  p->v,
1390  attn_token,
1391  H,
1392  H_kv,
1393  /*kv_tokens=*/token_index + 1,
1394  cache_capacity,
1395  hd,
1396  ad);
1397 
1398  // Output projection (Wo) into token-major buffer (decode-specialized).
/* (Callee name lost at listing line 1399 — takes attn_token.) */
1400  p->wo,
1401  p->bo,
1402  proj_row,
1403  D,
1404  aligned_D,
1405  H,
1406  ad);
1407 
1408  // Residual + LN2 / RMSNorm.
1409  ck_residual_add_token_major(input_row,
1410  proj_row,
1411  residual_row,
1412  /*tokens=*/1,
1413  aligned_D);
1414 
1415  rmsnorm_forward(residual_row,
1416  p->ln2_gamma,
1417  ln2_row,
1418  ln2_rstd,
1419  /*tokens=*/1,
1420  D,
1421  aligned_D,
1422  p->eps);
1423 
1424  // MLP block for this token.
1425  int up_dim = 2 * aligned_intermediate;
1426  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
1427  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
1428 
1429  ck_mlp_swiglu_forward(ln2_row,
1430  p->w1,
1431  p->b1,
1432  p->w2,
1433  p->b2,
1434  fc1_row,
1435  swiglu_row,
1436  mlp_row,
1437  /*tokens=*/1,
1438  aligned_D,
1439  aligned_intermediate);
1440 
1441  // Final residual.
1442  ck_residual_add_token_major(residual_row,
1443  mlp_row,
1444  out_row,
1445  /*tokens=*/1,
1446  aligned_D);
1447 }
1448 
/*
 * NOTE(review): function-name line (listing 1449) and several callee-name
 * lines were dropped by the Doxygen extraction.  Same decode pipeline as the
 * previous function, except the MLP uses the fully fused single-token path
 * (note the fc1_out buffer is not required here — only swiglu_out — and the
 * fused MLP callee at listing line 1587 writes straight to mlp_row).
 *
 * NOTE(review): q_token/k_token/v_token/attn_token are runtime-sized VLAs —
 * same stack-usage caveat as the non-fused decode path.
 */
1450  int token_index,
1451  int cache_capacity)
1452 {
1453  if (!p) {
1454  return;
1455  }
1456  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
1457  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
1458  !p->k || !p->v || !p->swiglu_out ||
1459  !p->proj_tmp || !p->residual1 || !p->mlp_out || !p->output) {
1460  return;
1461  }
1462  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
1463  return;
1464  }
1465  if (p->num_heads <= 0 || p->num_kv_heads <= 0 || p->aligned_head_dim <= 0) {
1466  return;
1467  }
1468 
1469  const int D = p->embed_dim;
1470  const int aligned_D = p->aligned_embed_dim;
1471  const int H = p->num_heads;
1472  const int H_kv = p->num_kv_heads;
1473  const int hd = p->head_dim;
1474  const int ad = p->aligned_head_dim;
1475  const int aligned_intermediate = p->aligned_intermediate_dim;
1476 
1477  /* Decode buffers are single-token; token_index only applies to KV cache. */
1478  const size_t token_slot = 0;
1479  const float *input_row = p->input + token_slot * (size_t)aligned_D;
1480  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
1481  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
1482  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
1483  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
1484  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
1485  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
1486  float *out_row = p->output + token_slot * (size_t)aligned_D;
1487 
1488  float ln1_rstd_tmp = 0.0f;
1489  float ln2_rstd_tmp = 0.0f;
1490  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
1491  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
1492 
1493  // Scratch for a single token in head-major layout: [head, aligned_head_dim].
1494  size_t q_elems = (size_t)H * (size_t)ad;
1495  size_t kv_elems = (size_t)H_kv * (size_t)ad;
1496  float q_token[q_elems];
1497  float k_token[kv_elems];
1498  float v_token[kv_elems];
1499  float attn_token[q_elems];
1500 
1501  // LN1 / RMSNorm.
1502  rmsnorm_forward(input_row,
1503  p->ln1_gamma,
1504  ln1_row,
1505  ln1_rstd,
1506  /*tokens=*/1,
1507  D,
1508  aligned_D,
1509  p->eps);
1510 
1511  // Project Q/K/V for the new token.
/* (Callee name lost at listing line 1512.) */
1513  p->wq, p->bq,
1514  p->wk, p->bk,
1515  p->wv, p->bv,
1516  q_token, k_token, v_token,
1517  aligned_D,
1518  H,
1519  H_kv,
1520  ad);
1521 
1522  // RoPE for the new token at absolute position `p->rope_pos_offset`.
1523  if (p->rope_cos && p->rope_sin) {
1524  rope_forward_qk(q_token,
1525  k_token,
1526  p->rope_cos,
1527  p->rope_sin,
1528  H,
1529  H_kv,
1530  /*num_tokens=*/1,
1531  hd,
1532  ad,
1533  p->rope_pos_offset);
1534  }
1535 
1536  // Update KV cache (stores k/v for this token and clears padded lanes).
1537  kv_cache_write_head_major(k_token,
1538  v_token,
1539  p->k,
1540  p->v,
1541  H_kv,
1542  token_index,
1543  cache_capacity,
1544  hd,
1545  ad);
1546 
1547  // Decode attention for this token using the KV cache.
/* (Callee name lost at listing line 1548.) */
1549  p->k,
1550  p->v,
1551  attn_token,
1552  H,
1553  H_kv,
1554  /*kv_tokens=*/token_index + 1,
1555  cache_capacity,
1556  hd,
1557  ad);
1558 
1559  // Output projection (Wo) into token-major buffer (decode-specialized).
/* (Callee name lost at listing line 1560.) */
1561  p->wo,
1562  p->bo,
1563  proj_row,
1564  D,
1565  aligned_D,
1566  H,
1567  ad);
1568 
1569  // Residual + LN2 / RMSNorm.
1570  ck_residual_add_token_major(input_row,
1571  proj_row,
1572  residual_row,
1573  /*tokens=*/1,
1574  aligned_D);
1575 
1576  rmsnorm_forward(residual_row,
1577  p->ln2_gamma,
1578  ln2_row,
1579  ln2_rstd,
1580  /*tokens=*/1,
1581  D,
1582  aligned_D,
1583  p->eps);
1584 
1585  // MLP block for this token (fully fused - all 3 projections in one pass).
1586  // Eliminates DRAM round-trip for swiglu intermediate values.
/* (Callee name lost at listing line 1587 — presumably
 * ck_mlp_swiglu_forward_fully_fused_token taking ln2_row; confirm.) */
1588  p->w1,
1589  p->b1,
1590  p->w2,
1591  p->b2,
1592  mlp_row,
1593  aligned_D,
1594  aligned_intermediate);
1595 
1596  // Final residual.
1597  ck_residual_add_token_major(residual_row,
1598  mlp_row,
1599  out_row,
1600  /*tokens=*/1,
1601  aligned_D);
1602 }
1603 
1604 static void ck_qkv_project_head_major_token_q4_k(const float *input_row,
1605  const void *wq, const float *bq,
1606  const void *wk, const float *bk,
1607  const void *wv, const float *bv,
1608  float *q_token,
1609  float *k_token,
1610  float *v_token,
1611  int aligned_embed_dim,
1612  int num_heads,
1613  int num_kv_heads,
1614  int aligned_head_dim)
1615 {
1616  if (!input_row || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
1617  return;
1618  }
1619 
1620  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1621  const size_t head_w_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1622 
1623  const uint8_t *wq_bytes = (const uint8_t *)wq;
1624  const uint8_t *wk_bytes = (const uint8_t *)wk;
1625  const uint8_t *wv_bytes = (const uint8_t *)wv;
1626 
1627  for (int h = 0; h < num_heads; ++h) {
1628  const void *wq_h = wq_bytes + (size_t)h * head_w_bytes;
1629  const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
1630  float *q_h = q_token + (size_t)h * (size_t)aligned_head_dim;
1631  gemm_nt_q4_k(input_row, wq_h, bq_h, q_h,
1632  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1633  }
1634 
1635  for (int h = 0; h < num_kv_heads; ++h) {
1636  const void *wk_h = wk_bytes + (size_t)h * head_w_bytes;
1637  const void *wv_h = wv_bytes + (size_t)h * head_w_bytes;
1638  const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
1639  const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
1640  float *k_h = k_token + (size_t)h * (size_t)aligned_head_dim;
1641  float *v_h = v_token + (size_t)h * (size_t)aligned_head_dim;
1642  gemm_nt_q4_k(input_row, wk_h, bk_h, k_h,
1643  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1644  gemm_nt_q4_k(input_row, wv_h, bv_h, v_h,
1645  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1646  }
1647 }
1648 
/*
 * NOTE(review): the first signature line (listing 1649, with the function
 * name and the first parameter) was dropped by the Doxygen extraction.  The
 * body references `input_q8` and calls gemm_nt_q4_k_q8_k, so this is
 * presumably the Q8_K-activation variant of the single-token head-major QKV
 * projection (Q4_K weights x pre-quantized Q8_K input row) — confirm the
 * exact name/first parameter against the original source.  Structure is
 * otherwise identical to ck_qkv_project_head_major_token_q4_k above.
 */
1650  const void *wq, const float *bq,
1651  const void *wk, const float *bk,
1652  const void *wv, const float *bv,
1653  float *q_token,
1654  float *k_token,
1655  float *v_token,
1656  int aligned_embed_dim,
1657  int num_heads,
1658  int num_kv_heads,
1659  int aligned_head_dim)
1660 {
1661  if (!input_q8 || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
1662  return;
1663  }
1664 
/* Per-head Q4_K weight slab size in elements and bytes. */
1665  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1666  const size_t head_w_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1667 
1668  const uint8_t *wq_bytes = (const uint8_t *)wq;
1669  const uint8_t *wk_bytes = (const uint8_t *)wk;
1670  const uint8_t *wv_bytes = (const uint8_t *)wv;
1671 
/* Q heads: one quantized GEMV per head. */
1672  for (int h = 0; h < num_heads; ++h) {
1673  const void *wq_h = wq_bytes + (size_t)h * head_w_bytes;
1674  const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
1675  float *q_h = q_token + (size_t)h * (size_t)aligned_head_dim;
1676  gemm_nt_q4_k_q8_k(input_q8, wq_h, bq_h, q_h,
1677  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1678  }
1679 
/* K/V heads (num_kv_heads may differ from num_heads). */
1680  for (int h = 0; h < num_kv_heads; ++h) {
1681  const void *wk_h = wk_bytes + (size_t)h * head_w_bytes;
1682  const void *wv_h = wv_bytes + (size_t)h * head_w_bytes;
1683  const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
1684  const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
1685  float *k_h = k_token + (size_t)h * (size_t)aligned_head_dim;
1686  float *v_h = v_token + (size_t)h * (size_t)aligned_head_dim;
1687  gemm_nt_q4_k_q8_k(input_q8, wk_h, bk_h, k_h,
1688  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1689  gemm_nt_q4_k_q8_k(input_q8, wv_h, bv_h, v_h,
1690  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1691  }
1692 }
1693 
/*
 * Multi-token head-major Q/K/V projection with Q4_K weights and Q8_K
 * activations.
 *
 * For each token: quantize the FP32 input row to Q8_K, run the per-token
 * projection (callee name lost — see note below), then scatter the
 * head-major token results into the [head, token, aligned_head_dim] output
 * layouts (Q strided by `tokens`, K/V strided by `kv_stride_tokens`).
 * Guards: positive tokens/dims, kv stride >= tokens, and aligned_embed_dim
 * divisible by the QK_K super-block size (required by quantize_row_q8_k).
 *
 * NOTE(review): q8_buf and the q/k/v token scratch arrays are runtime-sized
 * VLAs — stack usage scales with embed dim and head counts.
 */
1694 static void ck_qkv_project_head_major_q4_k_q8_k(const float *input,
1695  const void *wq, const float *bq,
1696  const void *wk, const float *bk,
1697  const void *wv, const float *bv,
1698  float *q, float *k, float *v,
1699  int tokens,
1700  int kv_stride_tokens,
1701  int aligned_embed_dim,
1702  int num_heads,
1703  int num_kv_heads,
1704  int aligned_head_dim)
1705 {
1706  if (!input || !wq || !wk || !wv || !q || !k || !v) {
1707  return;
1708  }
1709  if (tokens <= 0 || aligned_embed_dim <= 0) {
1710  return;
1711  }
1712  if (kv_stride_tokens < tokens) {
1713  return;
1714  }
1715  if ((aligned_embed_dim % QK_K) != 0) {
1716  return;
1717  }
1718 
1719  const int q8_blocks = aligned_embed_dim / QK_K;
1720  block_q8_K q8_buf[q8_blocks];
1721  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
1722  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
1723 
1724  float q_token[num_heads * aligned_head_dim];
1725  float k_token[num_kv_heads * aligned_head_dim];
1726  float v_token[num_kv_heads * aligned_head_dim];
1727 
1728  for (int t = 0; t < tokens; ++t) {
1729  const float *input_row = input + (size_t)t * (size_t)aligned_embed_dim;
1730  quantize_row_q8_k(input_row, q8_buf, aligned_embed_dim);
1731 
/* (Callee name lost at listing line 1732 — presumably the Q8_K single-token
 * QKV projection, taking q8_buf; confirm against the original source.) */
1733  wq, bq,
1734  wk, bk,
1735  wv, bv,
1736  q_token,
1737  k_token,
1738  v_token,
1739  aligned_embed_dim,
1740  num_heads,
1741  num_kv_heads,
1742  aligned_head_dim);
1743 
/* Scatter this token's Q heads into the head-major output. */
1744  for (int h = 0; h < num_heads; ++h) {
1745  float *q_dst = q + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1746  memcpy(q_dst,
1747  q_token + (size_t)h * (size_t)aligned_head_dim,
1748  (size_t)aligned_head_dim * sizeof(float));
1749  }
1750 
/* Scatter K/V heads (kv stride allows a larger cache allocation). */
1751  for (int h = 0; h < num_kv_heads; ++h) {
1752  float *k_dst = k + (size_t)h * kv_head_stride + (size_t)t * (size_t)aligned_head_dim;
1753  float *v_dst = v + (size_t)h * kv_head_stride + (size_t)t * (size_t)aligned_head_dim;
1754  memcpy(k_dst,
1755  k_token + (size_t)h * (size_t)aligned_head_dim,
1756  (size_t)aligned_head_dim * sizeof(float));
1757  memcpy(v_dst,
1758  v_token + (size_t)h * (size_t)aligned_head_dim,
1759  (size_t)aligned_head_dim * sizeof(float));
1760  }
1761  }
1762 }
1763 
1764 static void ck_attention_project_head_major_q4_k_q8_k(const float *attn_out,
1765  const void *wo,
1766  const float *bo,
1767  float *out,
1768  int tokens,
1769  int aligned_embed_dim,
1770  int num_heads,
1771  int aligned_head_dim)
1772 {
1773  if (!attn_out || !wo || !out) {
1774  return;
1775  }
1776  if (tokens <= 0 || aligned_embed_dim <= 0) {
1777  return;
1778  }
1779  if ((aligned_embed_dim % QK_K) != 0) {
1780  return;
1781  }
1782 
1783  const int K = num_heads * aligned_head_dim;
1784  if (K != aligned_embed_dim) {
1785  return;
1786  }
1787 
1788  const int q8_blocks = aligned_embed_dim / QK_K;
1789  block_q8_K q8_buf[q8_blocks];
1790  float attn_token[aligned_embed_dim];
1791  const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
1792 
1793  for (int t = 0; t < tokens; ++t) {
1794  for (int h = 0; h < num_heads; ++h) {
1795  const float *src = attn_out + (size_t)h * head_stride + (size_t)t * (size_t)aligned_head_dim;
1796  memcpy(attn_token + (size_t)h * (size_t)aligned_head_dim,
1797  src,
1798  (size_t)aligned_head_dim * sizeof(float));
1799  }
1800 
1801  quantize_row_q8_k(attn_token, q8_buf, aligned_embed_dim);
1802  gemm_nt_q4_k_q8_k(q8_buf, wo, bo,
1803  out + (size_t)t * (size_t)aligned_embed_dim,
1804  /*M=*/1, /*N=*/aligned_embed_dim, /*K=*/aligned_embed_dim);
1805  }
1806 }
1807 
1808 static void ck_mlp_swiglu_forward_q4_k_q8_k_prefill(const float *input,
1809  const void *w1,
1810  const float *b1,
1811  const void *w2,
1812  const float *b2,
1813  float *fc1_out,
1814  float *swiglu_out,
1815  float *output,
1816  int tokens,
1817  int aligned_embed_dim,
1818  int aligned_intermediate_dim)
1819 {
1820  if (!input || !w1 || !w2 || !fc1_out || !swiglu_out || !output) {
1821  return;
1822  }
1823  if (tokens <= 0) {
1824  return;
1825  }
1826  if ((aligned_embed_dim % QK_K) != 0 || (aligned_intermediate_dim % QK_K) != 0) {
1827  return;
1828  }
1829 
1830  const int up_dim = 2 * aligned_intermediate_dim;
1831  const int q8_blocks_embed = aligned_embed_dim / QK_K;
1832  const int q8_blocks_inter = aligned_intermediate_dim / QK_K;
1833  const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
1834  block_q8_K q8_buf[q8_blocks_max];
1835 
1836  for (int t = 0; t < tokens; ++t) {
1837  const float *input_row = input + (size_t)t * (size_t)aligned_embed_dim;
1838  float *fc1_row = fc1_out + (size_t)t * (size_t)up_dim;
1839 
1840  quantize_row_q8_k(input_row, q8_buf, aligned_embed_dim);
1841  gemm_nt_q4_k_q8_k(q8_buf, w1, b1, fc1_row,
1842  /*M=*/1, /*N=*/up_dim, /*K=*/aligned_embed_dim);
1843  }
1844 
1845  swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
1846 
1847  for (int t = 0; t < tokens; ++t) {
1848  const float *swiglu_row = swiglu_out + (size_t)t * (size_t)aligned_intermediate_dim;
1849  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
1850 
1851  quantize_row_q8_k(swiglu_row, q8_buf, aligned_intermediate_dim);
1852  gemm_nt_q4_k_q8_k(q8_buf, w2, b2, out_row,
1853  /*M=*/1, /*N=*/aligned_embed_dim, /*K=*/aligned_intermediate_dim);
1854  }
1855 }
1856 
1857 static void ck_qkv_project_head_major_token_quant(const float *input_row,
1858  const void *wq, const float *bq, CKDataType wq_dtype,
1859  const void *wk, const float *bk, CKDataType wk_dtype,
1860  const void *wv, const float *bv, CKDataType wv_dtype,
1861  float *q_token,
1862  float *k_token,
1863  float *v_token,
1864  int aligned_embed_dim,
1865  int num_heads,
1866  int num_kv_heads,
1867  int aligned_head_dim)
1868 {
1869  if (!input_row || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
1870  return;
1871  }
1872 
1873  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1874  const size_t wq_head_bytes = ck_dtype_row_bytes(wq_dtype, head_w_elems);
1875  const size_t wk_head_bytes = ck_dtype_row_bytes(wk_dtype, head_w_elems);
1876  const size_t wv_head_bytes = ck_dtype_row_bytes(wv_dtype, head_w_elems);
1877 
1878  const uint8_t *wq_bytes = (const uint8_t *)wq;
1879  const uint8_t *wk_bytes = (const uint8_t *)wk;
1880  const uint8_t *wv_bytes = (const uint8_t *)wv;
1881 
1882  for (int h = 0; h < num_heads; ++h) {
1883  const void *wq_h = (wq_dtype == CK_DT_FP32)
1884  ? (const void *)((const float *)wq + (size_t)h * head_w_elems)
1885  : (const void *)(wq_bytes + (size_t)h * wq_head_bytes);
1886  const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
1887  float *q_h = q_token + (size_t)h * (size_t)aligned_head_dim;
1888  ck_gemm_nt_quant(input_row, wq_h, bq_h, q_h,
1889  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim, wq_dtype);
1890  }
1891 
1892  for (int h = 0; h < num_kv_heads; ++h) {
1893  const void *wk_h = (wk_dtype == CK_DT_FP32)
1894  ? (const void *)((const float *)wk + (size_t)h * head_w_elems)
1895  : (const void *)(wk_bytes + (size_t)h * wk_head_bytes);
1896  const void *wv_h = (wv_dtype == CK_DT_FP32)
1897  ? (const void *)((const float *)wv + (size_t)h * head_w_elems)
1898  : (const void *)(wv_bytes + (size_t)h * wv_head_bytes);
1899  const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
1900  const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
1901  float *k_h = k_token + (size_t)h * (size_t)aligned_head_dim;
1902  float *v_h = v_token + (size_t)h * (size_t)aligned_head_dim;
1903  ck_gemm_nt_quant(input_row, wk_h, bk_h, k_h,
1904  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim, wk_dtype);
1905  ck_gemm_nt_quant(input_row, wv_h, bv_h, v_h,
1906  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim, wv_dtype);
1907  }
1908 }
1909 
/*
 * NOTE(review): the signature (listing line 1910) and many hyperlinked lines
 * were dropped by the Doxygen extraction — including an outer `if` condition
 * at listing line 1928 (note the double closing brace at listing 2021/2022)
 * and most callee names.  From the surviving structure this is a quantized
 * prefill layer forward: when some dtype condition (lost line 1928) holds AND
 * both aligned dims are QK_K-multiples, it runs the Q4_K/Q8_K pipeline
 * (quantized QKV, attention, Wo projection and MLP — argument lists match
 * ck_attention_project_head_major_q4_k_q8_k and
 * ck_mlp_swiglu_forward_q4_k_q8_k_prefill) and returns; otherwise it falls
 * through to the generic/FP32 pipeline.  TODO: recover the lost lines from
 * the original ckernel_orchestration.c before reusing this code.
 */
1911 {
1912  if (!p) {
1913  return;
1914  }
1915 
1916  const int aligned_D = p->aligned_embed_dim;
1917  const int aligned_intermediate = p->aligned_intermediate_dim;
1918 
/* LN1 / RMSNorm over p->input (callee name lost at listing line 1919). */
1920  p->ln1_gamma,
1921  p->ln1_out,
1922  p->ln1_rstd,
1923  p->tokens,
1924  p->embed_dim,
1925  aligned_D,
1926  p->eps);
1927 
/* (Outer quantized-path condition lost at listing line 1928.) */
1929  if ((aligned_D % QK_K) == 0 && (aligned_intermediate % QK_K) == 0) {
/* Quantized QKV projection (callee lost at line 1930). */
1931  p->wq, p->bq,
1932  p->wk, p->bk,
1933  p->wv, p->bv,
1934  p->q, p->k, p->v,
1935  p->tokens,
1936  p->tokens,
1937  aligned_D,
1938  p->num_heads,
1939  p->num_kv_heads,
1940  p->aligned_head_dim);
1941 
1942  if (p->rope_cos && p->rope_sin) {
1943  rope_forward_qk(p->q,
1944  p->k,
1945  p->rope_cos,
1946  p->rope_sin,
1947  p->num_heads,
1948  p->num_kv_heads,
1949  p->tokens,
1950  p->head_dim,
1951  p->aligned_head_dim,
1952  p->rope_pos_offset);
1953  }
1954 
/* Attention (callees lost at 1956 / 1968; trailing arg at 1966). */
1955  if (p->scores) {
1957  p->k,
1958  p->v,
1959  p->scores,
1960  p->attn_out,
1961  p->num_heads,
1962  p->num_kv_heads,
1963  p->tokens,
1964  p->head_dim,
1965  p->aligned_head_dim,
1967  } else {
1969  p->k,
1970  p->v,
1971  p->attn_out,
1972  p->num_heads,
1973  p->num_kv_heads,
1974  p->tokens,
1975  p->head_dim,
1976  p->aligned_head_dim);
1977  }
1978 
/* Quantized Wo projection (callee lost at 1979 — argument list matches
 * ck_attention_project_head_major_q4_k_q8_k). */
1980  p->wo,
1981  p->bo,
1982  p->proj_tmp,
1983  p->tokens,
1984  aligned_D,
1985  p->num_heads,
1986  p->aligned_head_dim);
1987 
/* Residual add (callee lost at line 1988). */
1989  p->proj_tmp,
1990  p->residual1,
1991  p->tokens,
1992  aligned_D);
1993 
/* LN2 / RMSNorm (callee lost at line 1994). */
1995  p->ln2_gamma,
1996  p->ln2_out,
1997  p->ln2_rstd,
1998  p->tokens,
1999  p->embed_dim,
2000  aligned_D,
2001  p->eps);
2002 
/* Quantized MLP (callee lost at 2003 — argument list matches
 * ck_mlp_swiglu_forward_q4_k_q8_k_prefill). */
2004  p->w1,
2005  p->b1,
2006  p->w2,
2007  p->b2,
2008  p->fc1_out,
2009  p->swiglu_out,
2010  p->mlp_out,
2011  p->tokens,
2012  aligned_D,
2013  aligned_intermediate);
2014 
/* Final residual add into p->output (callee lost at line 2015). */
2016  p->mlp_out,
2017  p->output,
2018  p->tokens,
2019  aligned_D);
2020  return;
2021  }
2022  }
2023 
/* Fallback path: generic (per-dtype / FP32) QKV projection, callee lost at
 * listing line 2024. */
2025  p->wq, p->bq,
2026  p->wk, p->bk,
2027  p->wv, p->bv,
2028  p->q, p->k, p->v,
2029  p->tokens,
2030  p->tokens,
2031  aligned_D,
2032  p->num_heads,
2033  p->num_kv_heads,
2034  p->aligned_head_dim);
2035 
2036  if (p->rope_cos && p->rope_sin) {
2037  rope_forward_qk(p->q,
2038  p->k,
2039  p->rope_cos,
2040  p->rope_sin,
2041  p->num_heads,
2042  p->num_kv_heads,
2043  p->tokens,
2044  p->head_dim,
2045  p->aligned_head_dim,
2046  p->rope_pos_offset);
2047  }
2048 
/* Attention (callees lost at 2050 / 2062; trailing arg at 2060). */
2049  if (p->scores) {
2051  p->k,
2052  p->v,
2053  p->scores,
2054  p->attn_out,
2055  p->num_heads,
2056  p->num_kv_heads,
2057  p->tokens,
2058  p->head_dim,
2059  p->aligned_head_dim,
2061  } else {
2063  p->k,
2064  p->v,
2065  p->attn_out,
2066  p->num_heads,
2067  p->num_kv_heads,
2068  p->tokens,
2069  p->head_dim,
2070  p->aligned_head_dim);
2071  }
2072 
/* FP32 Wo projection with scratch buffer (callee lost at line 2073). */
2074  p->wo,
2075  p->bo,
2076  p->proj_tmp,
2077  p->proj_scratch,
2078  p->tokens,
2079  p->aligned_embed_dim,
2080  p->num_heads,
2081  p->aligned_head_dim);
2082 
/* Residual add (callee lost at line 2083). */
2084  p->proj_tmp,
2085  p->residual1,
2086  p->tokens,
2087  p->aligned_embed_dim);
2088 
/* LN2 / RMSNorm (callee lost at line 2089). */
2090  p->ln2_gamma,
2091  p->ln2_out,
2092  p->ln2_rstd,
2093  p->tokens,
2094  p->embed_dim,
2095  p->aligned_embed_dim,
2096  p->eps);
2097 
/* SwiGLU MLP (callee lost at 2098; trailing intermediate arg at 2108). */
2099  p->w1,
2100  p->b1,
2101  p->w2,
2102  p->b2,
2103  p->fc1_out,
2104  p->swiglu_out,
2105  p->mlp_out,
2106  p->tokens,
2107  p->aligned_embed_dim,
2109 
/* Final residual add into p->output (callee lost at line 2110). */
2111  p->mlp_out,
2112  p->output,
2113  p->tokens,
2114  p->aligned_embed_dim);
2115 }
2116 
2118  int token_index,
2119  int cache_capacity)
2120 {
2121  if (!p) {
2122  return;
2123  }
2124  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
2125  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
2126  !p->k || !p->v ||
2127  !p->proj_tmp || !p->residual1 || !p->fc1_out || !p->swiglu_out || !p->mlp_out || !p->output) {
2128  return;
2129  }
2130  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
2131  return;
2132  }
2133 
2134  const int D = p->embed_dim;
2135  const int aligned_D = p->aligned_embed_dim;
2136  const int H = p->num_heads;
2137  const int H_kv = p->num_kv_heads;
2138  const int hd = p->head_dim;
2139  const int ad = p->aligned_head_dim;
2140  const int aligned_intermediate = p->aligned_intermediate_dim;
2141  const int K_concat = H * ad;
2142 
2143  /* Decode buffers are single-token; token_index only applies to KV cache. */
2144  const size_t token_slot = 0;
2145  const float *input_row = p->input + token_slot * (size_t)aligned_D;
2146  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
2147  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
2148  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
2149  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
2150  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
2151  float *out_row = p->output + token_slot * (size_t)aligned_D;
2152 
2153  float ln1_rstd_tmp = 0.0f;
2154  float ln2_rstd_tmp = 0.0f;
2155  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
2156  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
2157 
2158  /* Scratch for a single token in head-major layout: [head, aligned_head_dim]. */
2159  size_t q_elems = (size_t)H * (size_t)ad;
2160  size_t kv_elems = (size_t)H_kv * (size_t)ad;
2161  float q_token[q_elems];
2162  float k_token[kv_elems];
2163  float v_token[kv_elems];
2164  float attn_token[q_elems];
2165 
2166  /* LN1 / RMSNorm. */
2167  ck_debug_check_buffer("input_row", input_row, aligned_D);
2168  rmsnorm_forward(input_row,
2169  p->ln1_gamma,
2170  ln1_row,
2171  ln1_rstd,
2172  /*tokens=*/1,
2173  D,
2174  aligned_D,
2175  p->eps);
2176  ck_debug_check_buffer("ln1_out (after rmsnorm)", ln1_row, aligned_D);
2177 
2179  if ((aligned_D % QK_K) == 0 && (aligned_intermediate % QK_K) == 0) {
2180  const int q8_blocks_embed = aligned_D / QK_K;
2181  const int q8_blocks_inter = aligned_intermediate / QK_K;
2182  const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
2183  block_q8_K q8_buf[q8_blocks_max];
2184 
2185  /* Project Q/K/V with Q8_K activations. */
2186  quantize_row_q8_k(ln1_row, q8_buf, aligned_D);
2187  ck_debug_check_q8k("q8_buf (after quantize)", q8_buf, q8_blocks_embed);
2188  ck_debug_check_q4k_weights("wq weights", p->wq, (aligned_D / QK_K) * (H * ad));
2190  p->wq, p->bq,
2191  p->wk, p->bk,
2192  p->wv, p->bv,
2193  q_token, k_token, v_token,
2194  aligned_D,
2195  H,
2196  H_kv,
2197  ad);
2198  ck_debug_check_buffer("q_token (after QKV proj)", q_token, (int)q_elems);
2199  ck_debug_check_buffer("k_token (after QKV proj)", k_token, (int)kv_elems);
2200  ck_debug_check_buffer("v_token (after QKV proj)", v_token, (int)kv_elems);
2201 
2202  /* RoPE for the new token at absolute position `p->rope_pos_offset`. */
2203  if (p->rope_cos && p->rope_sin) {
2204  rope_forward_qk(q_token,
2205  k_token,
2206  p->rope_cos,
2207  p->rope_sin,
2208  H,
2209  H_kv,
2210  /*num_tokens=*/1,
2211  hd,
2212  ad,
2213  p->rope_pos_offset);
2214  }
2215 
2216  /* Update KV cache. */
2217  kv_cache_write_head_major(k_token,
2218  v_token,
2219  p->k,
2220  p->v,
2221  H_kv,
2222  token_index,
2223  cache_capacity,
2224  hd,
2225  ad);
2226 
2227  /* Decode attention for this token using the KV cache. */
2229  p->k,
2230  p->v,
2231  attn_token,
2232  H,
2233  H_kv,
2234  /*kv_tokens=*/token_index + 1,
2235  cache_capacity,
2236  hd,
2237  ad);
2238  ck_debug_check_buffer("attn_token (after attention)", attn_token, (int)q_elems);
2239 
2240  /* Quantized output projection (Wo) with Q8_K activations. */
2241  quantize_row_q8_k(attn_token, q8_buf, aligned_D);
2242  gemm_nt_q4_k_q8_k(q8_buf,
2243  p->wo,
2244  p->bo,
2245  proj_row,
2246  /*M=*/1,
2247  aligned_D,
2248  /*K=*/K_concat);
2249  ck_debug_check_buffer("proj_row (after Wo proj)", proj_row, aligned_D);
2250 
2251  for (int j = D; j < aligned_D; ++j) {
2252  proj_row[j] = 0.0f;
2253  }
2254 
2255  /* Residual + LN2 / RMSNorm. */
2256  ck_residual_add_token_major(input_row,
2257  proj_row,
2258  residual_row,
2259  /*tokens=*/1,
2260  aligned_D);
2261 
2262  rmsnorm_forward(residual_row,
2263  p->ln2_gamma,
2264  ln2_row,
2265  ln2_rstd,
2266  /*tokens=*/1,
2267  D,
2268  aligned_D,
2269  p->eps);
2270 
2271  /* MLP block for this token (Q8_K activations). */
2272  int up_dim = 2 * aligned_intermediate;
2273  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
2274  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
2275 
2276  ck_mlp_swiglu_forward_q4_k_q8_k(ln2_row,
2277  p->w1,
2278  p->b1,
2279  p->w2,
2280  p->b2,
2281  fc1_row,
2282  swiglu_row,
2283  mlp_row,
2284  aligned_D,
2285  aligned_intermediate);
2286  ck_debug_check_buffer("mlp_row (after MLP)", mlp_row, aligned_D);
2287 
2288  /* Final residual. */
2289  ck_residual_add_token_major(residual_row,
2290  mlp_row,
2291  out_row,
2292  /*tokens=*/1,
2293  aligned_D);
2294  ck_debug_check_buffer("out_row (final output)", out_row, aligned_D);
2295  return;
2296  }
2297  }
2298 
2299  /* Project Q/K/V for the new token (Q4_K weights). */
2300  ck_qkv_project_head_major_token_q4_k(ln1_row,
2301  p->wq, p->bq,
2302  p->wk, p->bk,
2303  p->wv, p->bv,
2304  q_token, k_token, v_token,
2305  aligned_D,
2306  H,
2307  H_kv,
2308  ad);
2309 
2310  /* RoPE for the new token at absolute position `p->rope_pos_offset`. */
2311  if (p->rope_cos && p->rope_sin) {
2312  rope_forward_qk(q_token,
2313  k_token,
2314  p->rope_cos,
2315  p->rope_sin,
2316  H,
2317  H_kv,
2318  /*num_tokens=*/1,
2319  hd,
2320  ad,
2321  p->rope_pos_offset);
2322  }
2323 
2324  /* Update KV cache. */
2325  kv_cache_write_head_major(k_token,
2326  v_token,
2327  p->k,
2328  p->v,
2329  H_kv,
2330  token_index,
2331  cache_capacity,
2332  hd,
2333  ad);
2334 
2335  /* Decode attention for this token using the KV cache. */
2337  p->k,
2338  p->v,
2339  attn_token,
2340  H,
2341  H_kv,
2342  /*kv_tokens=*/token_index + 1,
2343  cache_capacity,
2344  hd,
2345  ad);
2346 
2347  /* Quantized output projection: Wo is stored as a flattened Q4_K matrix. */
2348  gemm_nt_q4_k(attn_token,
2349  p->wo,
2350  p->bo,
2351  proj_row,
2352  /*M=*/1,
2353  aligned_D,
2354  /*K=*/K_concat);
2355 
2356  for (int j = D; j < aligned_D; ++j) {
2357  proj_row[j] = 0.0f;
2358  }
2359 
2360  /* Residual + LN2 / RMSNorm. */
2361  ck_residual_add_token_major(input_row,
2362  proj_row,
2363  residual_row,
2364  /*tokens=*/1,
2365  aligned_D);
2366 
2367  rmsnorm_forward(residual_row,
2368  p->ln2_gamma,
2369  ln2_row,
2370  ln2_rstd,
2371  /*tokens=*/1,
2372  D,
2373  aligned_D,
2374  p->eps);
2375 
2376  /* MLP block for this token. */
2377  int up_dim = 2 * aligned_intermediate;
2378  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
2379  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
2380 
2381  ck_mlp_swiglu_forward_q4_k(ln2_row,
2382  p->w1,
2383  p->b1,
2384  p->w2,
2385  p->b2,
2386  fc1_row,
2387  swiglu_row,
2388  mlp_row,
2389  /*tokens=*/1,
2390  aligned_D,
2391  aligned_intermediate);
2392 
2393  /* Final residual. */
2394  ck_residual_add_token_major(residual_row,
2395  mlp_row,
2396  out_row,
2397  /*tokens=*/1,
2398  aligned_D);
2399 }
2400 
2402 {
2403  if (!p) {
2404  return;
2405  }
2406 
2408  p->ln1_gamma,
2409  p->ln1_out,
2410  p->ln1_rstd,
2411  p->tokens,
2412  p->embed_dim,
2413  p->aligned_embed_dim,
2414  p->eps);
2415 
2417  p->wq, p->bq, p->wq_dtype,
2418  p->wk, p->bk, p->wk_dtype,
2419  p->wv, p->bv, p->wv_dtype,
2420  p->q, p->k, p->v,
2421  p->tokens,
2422  p->tokens,
2423  p->aligned_embed_dim,
2424  p->num_heads,
2425  p->num_kv_heads,
2426  p->aligned_head_dim);
2427 
2428  if (p->rope_cos && p->rope_sin) {
2429  rope_forward_qk(p->q,
2430  p->k,
2431  p->rope_cos,
2432  p->rope_sin,
2433  p->num_heads,
2434  p->num_kv_heads,
2435  p->tokens,
2436  p->head_dim,
2437  p->aligned_head_dim,
2438  p->rope_pos_offset);
2439  }
2440 
2441  if (p->scores) {
2443  p->k,
2444  p->v,
2445  p->scores,
2446  p->attn_out,
2447  p->num_heads,
2448  p->num_kv_heads,
2449  p->tokens,
2450  p->head_dim,
2451  p->aligned_head_dim,
2453  } else {
2455  p->k,
2456  p->v,
2457  p->attn_out,
2458  p->num_heads,
2459  p->num_kv_heads,
2460  p->tokens,
2461  p->head_dim,
2462  p->aligned_head_dim);
2463  }
2464 
2466  p->wo,
2467  p->bo,
2468  p->proj_tmp,
2469  p->proj_scratch,
2470  p->tokens,
2471  p->aligned_embed_dim,
2472  p->num_heads,
2473  p->aligned_head_dim,
2474  p->wo_dtype);
2475 
2477  p->proj_tmp,
2478  p->residual1,
2479  p->tokens,
2480  p->aligned_embed_dim);
2481 
2483  p->ln2_gamma,
2484  p->ln2_out,
2485  p->ln2_rstd,
2486  p->tokens,
2487  p->embed_dim,
2488  p->aligned_embed_dim,
2489  p->eps);
2490 
2492  p->w1,
2493  p->b1,
2494  p->w1_dtype,
2495  p->w2,
2496  p->b2,
2497  p->w2_dtype,
2498  p->fc1_out,
2499  p->swiglu_out,
2500  p->mlp_out,
2501  p->tokens,
2502  p->aligned_embed_dim,
2504 
2506  p->mlp_out,
2507  p->output,
2508  p->tokens,
2509  p->aligned_embed_dim);
2510 }
2511 
2512 void ck_layer_forward_rmsnorm_swiglu_decode_quant(const CKLayerForwardParamsQ4K *p,
2513  int token_index,
2514  int cache_capacity)
2515 {
2516  if (!p) {
2517  return;
2518  }
2519  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
2520  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
2521  !p->k || !p->v ||
2522  !p->proj_tmp || !p->proj_scratch || !p->residual1 || !p->fc1_out || !p->swiglu_out || !p->mlp_out || !p->output) {
2523  return;
2524  }
2525  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
2526  return;
2527  }
2528 
2529  const int D = p->embed_dim;
2530  const int aligned_D = p->aligned_embed_dim;
2531  const int H = p->num_heads;
2532  const int H_kv = p->num_kv_heads;
2533  const int hd = p->head_dim;
2534  const int ad = p->aligned_head_dim;
2535  const int aligned_intermediate = p->aligned_intermediate_dim;
2536  const int K_concat = H * ad;
2537 
2538  /* Decode buffers are single-token; token_index only applies to KV cache. */
2539  const size_t token_slot = 0;
2540  const float *input_row = p->input + token_slot * (size_t)aligned_D;
2541  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
2542  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
2543  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
2544  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
2545  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
2546  float *out_row = p->output + token_slot * (size_t)aligned_D;
2547 
2548  float ln1_rstd_tmp = 0.0f;
2549  float ln2_rstd_tmp = 0.0f;
2550  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
2551  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
2552 
2553  size_t q_elems = (size_t)H * (size_t)ad;
2554  size_t kv_elems = (size_t)H_kv * (size_t)ad;
2555  float q_token[q_elems];
2556  float k_token[kv_elems];
2557  float v_token[kv_elems];
2558  float attn_token[q_elems];
2559 
2560  rmsnorm_forward(input_row,
2561  p->ln1_gamma,
2562  ln1_row,
2563  ln1_rstd,
2564  /*tokens=*/1,
2565  D,
2566  aligned_D,
2567  p->eps);
2568 
2569  ck_qkv_project_head_major_token_quant(ln1_row,
2570  p->wq, p->bq, p->wq_dtype,
2571  p->wk, p->bk, p->wk_dtype,
2572  p->wv, p->bv, p->wv_dtype,
2573  q_token, k_token, v_token,
2574  aligned_D,
2575  H,
2576  H_kv,
2577  ad);
2578 
2579  if (p->rope_cos && p->rope_sin) {
2580  rope_forward_qk(q_token,
2581  k_token,
2582  p->rope_cos,
2583  p->rope_sin,
2584  H,
2585  H_kv,
2586  /*num_tokens=*/1,
2587  hd,
2588  ad,
2589  p->rope_pos_offset);
2590  }
2591 
2592  kv_cache_write_head_major(k_token,
2593  v_token,
2594  p->k,
2595  p->v,
2596  H_kv,
2597  token_index,
2598  cache_capacity,
2599  hd,
2600  ad);
2601 
2603  p->k,
2604  p->v,
2605  attn_token,
2606  H,
2607  H_kv,
2608  /*kv_tokens=*/token_index + 1,
2609  cache_capacity,
2610  hd,
2611  ad);
2612 
2613  if (p->wo_dtype == CK_DT_FP32) {
2615  (const float *)p->wo,
2616  p->bo,
2617  proj_row,
2618  D,
2619  aligned_D,
2620  H,
2621  ad);
2622  } else {
2623  /* Quantized attention output projection - handle all quant types */
2624  ck_gemm_nt_quant(attn_token,
2625  p->wo,
2626  p->bo,
2627  proj_row,
2628  /*M=*/1,
2629  aligned_D,
2630  /*K=*/K_concat,
2631  p->wo_dtype);
2632  for (int j = D; j < aligned_D; ++j) {
2633  proj_row[j] = 0.0f;
2634  }
2635  }
2636 
2637  ck_residual_add_token_major(input_row,
2638  proj_row,
2639  residual_row,
2640  /*tokens=*/1,
2641  aligned_D);
2642 
2643  rmsnorm_forward(residual_row,
2644  p->ln2_gamma,
2645  ln2_row,
2646  ln2_rstd,
2647  /*tokens=*/1,
2648  D,
2649  aligned_D,
2650  p->eps);
2651 
2652  int up_dim = 2 * aligned_intermediate;
2653  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
2654  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
2655 
2656  ck_mlp_swiglu_forward_quant(ln2_row,
2657  p->w1,
2658  p->b1,
2659  p->w1_dtype,
2660  p->w2,
2661  p->b2,
2662  p->w2_dtype,
2663  fc1_row,
2664  swiglu_row,
2665  mlp_row,
2666  /*tokens=*/1,
2667  aligned_D,
2668  aligned_intermediate);
2669 
2670  ck_residual_add_token_major(residual_row,
2671  mlp_row,
2672  out_row,
2673  /*tokens=*/1,
2674  aligned_D);
2675 }
2676 
2677 void ck_layer_backward_rmsnorm_swiglu(const CKLayerBackwardParams *p)
2678 {
2679  if (!p) {
2680  return;
2681  }
2682 
2683  int T = p->tokens;
2684  int aligned_embed = p->aligned_embed_dim;
2685  int aligned_head = p->aligned_head_dim;
2686  int aligned_intermediate = p->aligned_intermediate_dim;
2687  int up_dim = 2 * aligned_intermediate;
2688  int num_threads = 1;
2689 
2690  // 1) Residual add (output = residual1 + mlp_out)
2691  ck_residual_add_backward(p->d_output, p->d_residual1, p->d_mlp_out, T, aligned_embed);
2692 
2693  // 2) MLP down proj backward
2694  fc2_backward_kernel(p->d_mlp_out,
2695  p->swiglu_out,
2696  p->w2,
2697  p->d_swiglu_out,
2698  p->d_w2,
2699  p->d_b2,
2700  T,
2701  aligned_intermediate,
2702  aligned_embed,
2703  num_threads);
2704 
2705  // 3) SwiGLU backward
2706  swiglu_backward(p->fc1_out, p->d_swiglu_out, p->d_fc1_out, T, aligned_intermediate);
2707 
2708  // 4) MLP up proj backward
2709  fc1_backward_kernel(p->d_fc1_out,
2710  p->ln2_out,
2711  p->w1,
2712  p->d_ln2_out,
2713  p->d_w1,
2714  p->d_b1,
2715  T,
2716  aligned_embed,
2717  up_dim,
2718  num_threads);
2719 
2720  // 5) RMSNorm (ln2) backward; reuse d_output as scratch for d_residual1_from_ln2
2722  p->residual1,
2723  p->ln2_gamma,
2724  p->ln2_rstd,
2725  p->d_output,
2726  p->d_ln2_gamma,
2727  T,
2728  p->embed_dim,
2729  aligned_embed);
2730  ck_add_inplace(p->d_residual1, p->d_output, T, aligned_embed);
2731 
2732  // 6) Residual add (residual1 = input + proj_tmp)
2733  ck_residual_add_backward(p->d_residual1, p->d_input, p->d_proj_tmp, T, aligned_embed);
2734 
2735  // 7) Attention projection backward
2737  p->attn_out,
2738  p->wo,
2739  p->d_attn_out,
2740  p->d_wo,
2741  p->d_bo,
2742  T,
2743  aligned_embed,
2744  p->num_heads,
2745  aligned_head);
2746 
2747  // 8) Attention backward
2749  p->q,
2750  p->k,
2751  p->v,
2752  p->scores,
2753  p->d_q,
2754  p->d_k,
2755  p->d_v,
2756  p->d_scores,
2757  p->num_heads,
2758  p->num_kv_heads,
2759  T,
2760  p->head_dim,
2761  aligned_head,
2763 
2764  // 9) RoPE backward (if enabled)
2765  if (p->rope_cos && p->rope_sin) {
2766  rope_backward_qk(p->d_q,
2767  p->d_k,
2768  p->d_q,
2769  p->d_k,
2770  p->rope_cos,
2771  p->rope_sin,
2772  p->num_heads,
2773  p->num_kv_heads,
2774  T,
2775  p->head_dim,
2776  aligned_head,
2777  p->rope_pos_offset);
2778  }
2779 
2780  // 10) QKV projection backward (scratch uses d_proj_tmp)
2782  p->d_k,
2783  p->d_v,
2784  p->ln1_out,
2785  p->wq,
2786  p->bq,
2787  p->wk,
2788  p->bk,
2789  p->wv,
2790  p->bv,
2791  p->d_ln1_out,
2792  p->d_wq,
2793  p->d_bq,
2794  p->d_wk,
2795  p->d_bk,
2796  p->d_wv,
2797  p->d_bv,
2798  p->d_proj_tmp,
2799  T,
2800  aligned_embed,
2801  p->num_heads,
2802  p->num_kv_heads,
2803  aligned_head,
2804  num_threads);
2805 
2806  // 11) RMSNorm (ln1) backward; reuse d_ln1_out as scratch for d_input_from_ln1
2808  p->input,
2809  p->ln1_gamma,
2810  p->ln1_rstd,
2811  p->d_ln1_out,
2812  p->d_ln1_gamma,
2813  T,
2814  p->embed_dim,
2815  aligned_embed);
2816  ck_add_inplace(p->d_input, p->d_ln1_out, T, aligned_embed);
2817 }
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
@ CK_DT_Q4_0
Definition: ckernel_dtype.h:38
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
@ CK_DT_Q4_1
Definition: ckernel_dtype.h:39
@ CK_DT_Q5_1
Definition: ckernel_dtype.h:45
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void gemm_nt_q4_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void gemm_naive_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:125
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void gemm_swiglu_fused(const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)
void swiglu_backward(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void attention_forward_causal_head_major_gqa(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void rope_backward_qk(const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:497
void gemm_nt_q4_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
void gemm_nt_q5_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void fc1_backward_kernel(const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:167
void gemm_nt_q6_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q8_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void fc2_backward_kernel(const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:118
void quantize_row_q8_k(const float *x, void *y, int k)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:448
void fused_mlp_swiglu_decode_v2(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
int ck_strict_parity_enabled(void)
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:661
static void ck_add_inplace(float *dst, const float *src, int tokens, int aligned_embed_dim)
static void ck_qkv_project_head_major_quant(const float *input, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_attention_project_head_major_quant(const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, CKDataType wo_dtype)
void ck_layer_backward_rmsnorm_swiglu(const CKLayerBackwardParams *p)
void ck_mlp_swiglu_forward_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *swiglu_row, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_attention_project_head_major(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_mlp_swiglu_forward(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_attention_project_head_major_ref(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_residual_add_backward(const float *d_out, float *d_a, float *d_b, int tokens, int aligned_embed_dim)
static void ck_mlp_swiglu_forward_quant(const float *input, const void *w1, const float *b1, CKDataType w1_dtype, const void *w2, const float *b2, CKDataType w2_dtype, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_qkv_project_head_major_token_q4_k_q8_k(const block_q8_K *input_q8, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_attention_project_head_major_q4_k(const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_q4_k_q8_k(const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_ref(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static int ck_q8k_activations_enabled(void)
void ck_layer_forward_rmsnorm_swiglu_quant(const CKLayerForwardParamsQ4K *p)
static int ck_layer_debug_enabled(void)
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_mlp_swiglu_forward_q4_k_q8_k(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_layer_forward_rmsnorm_swiglu_decode(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_q4_k(const CKLayerForwardParamsQ4K *p)
void ck_qkv_project_head_major(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_debug_check_q8k(const char *stage, const void *q8_buf, int num_blocks)
static void ck_attention_project_head_major_q4_k_q8_k(const float *attn_out, const void *wo, const float *bo, float *out, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_mlp_swiglu_forward_q4_k(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
void ck_attention_flash_decode_wrapper(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
Wrapper to call TRUE flash attention from orchestration layer.
void ck_qkv_project_head_major_backward(const float *d_q, const float *d_k, const float *d_v, const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *d_input, float *d_wq, float *d_bq, float *d_wk, float *d_bk, float *d_wv, float *d_bv, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
void ck_layer_forward_rmsnorm_swiglu(const CKLayerForwardParams *p)
void ck_layer_forward_rmsnorm_swiglu_decode_fused(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_decode_q4_k(const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
static void ck_debug_check_q4k_weights(const char *stage, const void *q4_buf, int num_blocks)
static void ck_debug_check_buffer(const char *stage, const float *buf, int size)
static void ck_mlp_swiglu_forward_ref(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_layer_forward_rmsnorm_swiglu_ref(const CKLayerForwardParams *p)
static void ck_qkv_project_head_major_token_q4_k(const float *input_row, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_q4_k(const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_mlp_swiglu_forward_q4_k_q8_k_prefill(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_qkv_project_head_major_token_quant(const float *input_row, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_attention_project_head_major_backward(const float *d_out, const float *attn_out, const float *wo, float *d_attn_out, float *d_wo, float *d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_layer_forward_rmsnorm_swiglu_decode_quant(const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
void ck_qkv_project_head_major_token(const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_attention_project_head_major_decode_token(const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
#define QK_K
#define C(color)
Definition: show_config.c:39