/* ckernel_engine.h — core math backend interface for C-Kernel-Engine.
 * (This listing was extracted from the project's Doxygen source documentation;
 *  some hyperlinked declarator lines were lost in extraction.) */
1 #ifndef CKERNEL_ENGINE_H
2 #define CKERNEL_ENGINE_H
3 
4 #include <stddef.h>
5 #include <stdint.h>
6 #include "cpu_features.h"
7 #include "ckernel_quant.h" /* INT8 block types (block_q8_0, block_q8_K, etc.) */
8 #include "mega_fused_attention.h"
9 
10 #ifdef __cplusplus
11 extern "C" {
12 #endif
13 
14 /**
15  * Core math backend interface for C-Kernel-Engine.
16  *
17  * This is intentionally minimal and matches the conventions already used
18  * in C-Transformer for GEMM kernels.
19  *
20  * Layout assumptions (LLM-style shapes):
21  * - A: [M x K], row-major, A(i,k) = A[i*K + k]
22  * - B: [N x K], row-major, B(j,k) = B[j*K + k]
23  * - C: [M x N], row-major, C(i,j) = C[i*N + j]
24  * - bias: optional [N], added per output column j
25  */
26 typedef struct {
27  void (*sgemm)(int M, int N, int K,
28  const float *A, int lda,
29  const float *B, int ldb,
30  const float *bias,
31  float *C, int ldc);
33 
34 /**
35  * Obtain the built-in native backend (single-node CPU, C + intrinsics).
36  */
38 
/* Enable stricter numeric parity (single-thread + double-accumulation GEMM).
 * Useful when bit-for-bit comparison against a reference implementation is
 * required; disable for best performance. */
void ck_set_strict_parity(int enabled);
int ck_strict_parity_enabled(void);

/* Thread configuration - call once at startup.
 * num_threads: 0 = auto-detect physical cores, >0 = use the specified count. */
void ck_set_num_threads(int num_threads);
int ck_get_num_threads(void);
int ck_get_physical_cores(void);
48 
/* Individual GEMM kernels copied from C-Transformer.
 * Layout (see header comment): A [M x K], B [N x K] row-major by output
 * channel, C [M x N]; bias is an optional [N] vector added per column. */
void gemm_naive_parallel(const float *A,
                         const float *B,
                         const float *bias,
                         float *C,
                         int M, int N, int K);

void gemm_avx512_parallel(const float *A,
                          const float *B,
                          const float *bias,
                          float *C,
                          int M, int N, int K);

void gemm_fine_grained_parallel(const float *A,
                                const float *B,
                                const float *bias,
                                float *C,
                                int M, int N, int K);

void gemm_blocked_serial(const float *A,
                         const float *B,
                         const float *bias,
                         float *C,
                         int M, int N, int K);

/* Reference BF16 GEMM (A/B/bias in BF16 stored as uint16_t, output BF16). */
void gemm_blocked_serial_bf16(const uint16_t *A,
                              const uint16_t *B,
                              const uint16_t *bias,
                              uint16_t *C,
                              int M, int N, int K);
80 
/* =============================================================================
 * Quantized (GGML-style) GEMM/GEMV helpers
 * =============================================================================
 *
 * These kernels are used for weight-only quantized inference (e.g. Q4_K_M).
 * The "NT" wrapper matches the engine's common layout:
 *   A: [M x K] fp32 (token-major)
 *   B: [N x K] quantized (row-major by output channel)
 *   C: [M x N] fp32
 *
 * NOTE: Q4_K requires K to be a multiple of 256 (QK_K).
 */

/* GEMV: y[M] = W[M x K] (quantized) @ x[K] (fp32). */
void gemv_q4_k(float *y,
               const void *W,
               const float *x,
               int M, int K);

void gemm_q4_k(float *Y,
               const void *W,
               const float *X,
               int M, int N, int K);

void gemm_nt_q4_k(const float *A,
                  const void *B,
                  const float *bias,
                  float *C,
                  int M, int N, int K);

/* Dequantize one row of Q4_K blocks into fp32. */
void dequant_q4_k_row(const void *src, float *dst, size_t n_elements);

void gemv_q6_k(float *y,
               const void *W,
               const float *x,
               int M, int K);

void gemm_q6_k(float *Y,
               const void *W,
               const float *X,
               int M, int N, int K);

void gemm_nt_q6_k(const float *A,
                  const void *B,
                  const float *bias,
                  float *C,
                  int M, int N, int K);

/* Simple quant GEMM (Q4_0, Q4_1, Q5_0, Q5_1, Q5_K, Q8_0). */
void gemm_nt_q4_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);
void gemm_nt_q4_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);
void gemm_nt_q5_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);
void gemm_nt_q8_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K);

/* GEMV versions (for decode mode - single token). */
void gemv_q4_0(float *y, const void *W, const float *x, int M, int K);
void gemv_q5_0(float *y, const void *W, const float *x, int M, int K);
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K);
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K);
void gemv_q8_0(float *y, const void *W, const float *x, int M, int K);

/* Parallel Q5_0 versions - caller provides ith/nth from an OpenMP region;
 * each thread processes only its assigned slice of rows. */
void gemv_q5_0_parallel(float *y, const void *W, const float *x,
                        int M, int K, int ith, int nth);
void gemv_q5_0_parallel_simd(float *y, const void *W, const float *x,
                             int M, int K, int ith, int nth);

void dequant_q6_k_row(const void *src, float *dst, size_t n_elements);

/* Simple quant dequantization (Q4_0, Q4_1, Q5_0, Q5_1, Q8_0). */
void dequant_q4_0_row(const void *src, float *dst, size_t n_elements);
void dequant_q4_1_row(const void *src, float *dst, size_t n_elements);
void dequant_q5_0_row(const void *src, float *dst, size_t n_elements);
void dequant_q5_1_row(const void *src, float *dst, size_t n_elements);
void dequant_q8_0_row(const void *src, float *dst, size_t n_elements);
156 
/* ============================================================================
 * INT8 ACTIVATION KERNELS
 * ============================================================================ */

/* Q8_0 quantization (32 elements per block, 34 bytes: 2-byte scale + 32 int8). */
void quantize_row_q8_0(const float *x, void *y, int k);

/* Batch Q8_0 quantization (row-major output for GEMM compatibility).
 * Output: each row at offset row * ((k/32) * 34) bytes. */
void quantize_batch_q8_0(const float *x, void *y, int num_rows, int k);

/* Q5_0 weights x Q8_0 activations. */
void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K);

/* Q8_0 weights x Q8_0 activations. */
void gemv_q8_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K);

/* Fused GEMV: quantize(FP32->Q8_0) + GEMV(Q5_0 weights) + bias add. */
void gemv_fused_q5_0_bias_dispatch(float *y, const void *W, const float *x,
                                   const float *bias, int M, int K);

/* Fused GEMV: quantize(FP32->Q8_0) + GEMV(Q8_0 weights) + bias add. */
void gemv_fused_q8_0_bias_dispatch(float *y, const void *W, const float *x,
                                   const float *bias, int M, int K);

/* Q8_K quantization (256 elements per block, super-block format). */
void quantize_row_q8_k(const float *x, void *y, int k);

/* Batch Q8_K quantization (row-major output for GEMM compatibility). */
void quantize_batch_q8_k(const float *x, void *y, int num_rows, int k);

void gemv_q4_k_q8_k(float *y,
                    const void *W,
                    const void *x_q8,
                    int M, int K);

/* Reference implementation (no SIMD) - for testing/comparison. */
void gemv_q4_k_q8_k_ref(float *y,
                        const void *W,
                        const void *x_q8,
                        int M, int K);

/* Parallel version: receives ith (thread index) and nth (total threads).
 * OpenMP is at orchestration level, kernel processes only rows [r0, r1). */
void gemv_q4_k_q8_k_parallel(float *y,
                             const void *W,
                             const void *x_q8,
                             int M, int K,
                             int ith, int nth);

/* Parallel SIMD version: combines AVX with parallel row splitting.
 * Includes row-ahead prefetching to hide memory latency (~50-70ns).
 * This is the fastest option for multi-threaded decode. */
void gemv_q4_k_q8_k_parallel_simd(float *y,
                                  const void *W,
                                  const void *x_q8,
                                  int M, int K,
                                  int ith, int nth);

void gemm_q4_k_q8_k(float *Y,
                    const void *W,
                    const void *X_q8,
                    int M, int N, int K);

void gemm_nt_q4_k_q8_k(const void *A_q8,
                       const void *B,
                       const float *bias,
                       float *C,
                       int M, int N, int K);

/* Q6_K x Q8_K quantized kernels. */
void vec_dot_q6_k_q8_k(int n, float *s, const void *vx, const void *vy);

void gemv_q6_k_q8_k(float *y,
                    const void *W,
                    const void *x_q8,
                    int M, int K);

/* Parallel Q6_K versions - caller provides ith/nth from an OpenMP region. */
void gemv_q6_k_q8_k_parallel(float *y, const void *W, const void *x_q8,
                             int M, int K, int ith, int nth);
void gemv_q6_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8,
                                  int M, int K, int ith, int nth);

void gemm_q6_k_q8_k(float *Y,
                    const void *W,
                    const void *X_q8,
                    int M, int N, int K);

void gemm_nt_q6_k_q8_k(const void *A_q8,
                       const void *B,
                       const float *bias,
                       float *C,
                       int M, int N, int K);

void gemm_nt_q8_0_q8_0(const void *A_q8,
                       const void *B,
                       const float *bias,
                       float *C,
                       int M, int N, int K);
257 
/* GEMM_NN: C[M,N] = A[M,K] @ B[K,N] + bias[N]
 * B is stored row-major as [K,N] (no transpose on B).
 * Used for the backward pass d_input = d_output @ W. */
void gemm_nn_parallel(const float *A,
                      const float *B,
                      const float *bias,
                      float *C,
                      int M, int N, int K);

void gemm_nn_avx512(const float *A,
                    const float *B,
                    const float *bias,
                    float *C,
                    int M, int N, int K);

void gemm_nn_blocked(const float *A,
                     const float *B,
                     const float *bias,
                     float *C,
                     int M, int N, int K);

/* Head-major output projection (reads attention output directly, no flatten).
 * Reads attn_out [num_heads, tokens, head_dim] with strided access. */
void ck_gemm_nt_head_major_q5_0(const float *attn_out,
                                const void *wo,
                                const float *bias,
                                float *output,
                                int tokens,
                                int embed_dim,
                                int num_heads,
                                int head_dim);

void ck_gemm_nt_head_major_q8_0(const float *attn_out,
                                const void *wo,
                                const float *bias,
                                float *output,
                                int tokens,
                                int embed_dim,
                                int num_heads,
                                int head_dim);
298 
/* GEMM_TN: C[M,N] = A[K,M].T @ B[K,N] + bias[N]
 * A is stored row-major as [K,M], B is stored row-major as [K,N].
 * Used for the backward pass d_W = d_output.T @ input. */
void gemm_tn_parallel(const float *A,
                      const float *B,
                      const float *bias,
                      float *C,
                      int M, int N, int K);

void gemm_tn_avx512(const float *A,
                    const float *B,
                    const float *bias,
                    float *C,
                    int M, int N, int K);

void gemm_tn_blocked(const float *A,
                     const float *B,
                     const float *bias,
                     float *C,
                     int M, int N, int K);

/* Fused GEMM operations (GEMM + bias + activation in one pass). */
void gemm_bias_relu_fused(const float *A,
                          const float *B,
                          const float *bias,
                          float *C,
                          int M, int N, int K);

void gemm_bias_gelu_fused(const float *A,
                          const float *B,
                          const float *bias,
                          float *C,
                          int M, int N, int K);

void gemm_bias_silu_fused(const float *A,
                          const float *B,
                          const float *bias,
                          float *C,
                          int M, int N, int K);

/* Fused GEMM + SwiGLU (LLaMA/SmolLM MLP gate+up projection).
 * Computes: output = SiLU(x @ W_gate + b_gate) * (x @ W_up + b_up).
 * Two GEMMs + SwiGLU fused into one pass - intermediates stay in registers. */
void gemm_swiglu_fused(const float *x,
                       const float *W_gate,
                       const float *W_up,
                       const float *b_gate, /* can be NULL */
                       const float *b_up,   /* can be NULL */
                       float *output,
                       int M, int N, int K);
349 
350 // =============================================================================
351 // Fully Fused MLP Decode Kernels (T=1 token generation)
352 // =============================================================================
353 // These kernels fuse the ENTIRE MLP block: Gate + Up + SwiGLU + Down
354 // Key benefit: Intermediate swiglu values stay in L1/L2, never touch DRAM
355 // Target: AVX-512 / Intel Xeon 5th Gen (Sapphire/Emerald Rapids)
356 
357 // Version 1: Tiled fusion with thread-local accumulators
358 // Best for: Small number of cores, when critical section overhead is low
360  const float *x, // [D] input
361  const float *W_gate, // [Hff, D] gate projection
362  const float *W_up, // [Hff, D] up projection
363  const float *W_down, // [D, Hff] down projection
364  const float *b_gate, // [Hff] or NULL
365  const float *b_up, // [Hff] or NULL
366  const float *b_down, // [D] or NULL
367  float *output, // [D] output
368  int D, // hidden dimension
369  int Hff); // intermediate dimension
370 
371 // Version 2: Two-phase (swiglu then down projection)
372 // Best for: Many cores (24+), avoids critical section, better parallelism
374  const float *x,
375  const float *W_gate,
376  const float *W_up,
377  const float *W_down,
378  const float *b_gate,
379  const float *b_up,
380  const float *b_down,
381  float *output,
382  int D,
383  int Hff);
384 
385 // Version 3: Tiled with atomic accumulation
386 // Best for: Large L2 cache (2MB+), good cache reuse
388  const float *x,
389  const float *W_gate,
390  const float *W_up,
391  const float *W_down,
392  const float *b_gate,
393  const float *b_up,
394  const float *b_down,
395  float *output,
396  int D,
397  int Hff);
398 
399 /* ============================================================================
400  * PREFILL FUSION KERNELS
401  * ============================================================================
402  * These kernels fuse operations for prefill (large batch/sequence) to avoid
403  * writing intermediate activations to DRAM. Fusion helps when activations
404  * exceed L3 cache size.
405  *
406  * For decode (single token), use the non-fused kernels as activations
407  * easily fit in L2 cache anyway.
408  */
409 
410 /**
411  * @brief Fused RMSNorm + QKV projection for prefill
412  *
413  * Tiles along token dimension to keep intermediate x_norm in L2 cache.
414  * Avoids ~7MB DRAM traffic per layer for seq_len=1024, hidden=896.
415  *
416  * @param scratch Temporary buffer from fused_rmsnorm_qkv_scratch_size()
417  */
419  const float *x, /* [seq_len × hidden] input */
420  const float *gamma, /* [hidden] RMSNorm weights */
421  const float *Wq, /* [q_dim × hidden] Q weights (transposed) */
422  const float *Wk, /* [kv_dim × hidden] K weights (transposed) */
423  const float *Wv, /* [kv_dim × hidden] V weights (transposed) */
424  float *Q, /* [seq_len × q_dim] output */
425  float *K, /* [seq_len × kv_dim] output */
426  float *V, /* [seq_len × kv_dim] output */
427  int seq_len,
428  int hidden,
429  int q_dim,
430  int kv_dim,
431  float eps,
432  float *scratch);
433 
434 /**
435  * @brief Fused RMSNorm + QKV projection for prefill (head-major outputs)
436  *
437  * Writes Q as [num_heads, seq_len, aligned_head_dim] and K/V with stride
438  * kv_stride_tokens for KV-cache compatibility.
439  */
441  const float *x,
442  const float *gamma,
443  const float *Wq, const float *Bq,
444  const float *Wk, const float *Bk,
445  const float *Wv, const float *Bv,
446  float *Q,
447  float *K,
448  float *V,
449  int seq_len,
450  int embed_dim,
451  int aligned_embed_dim,
452  int num_heads,
453  int num_kv_heads,
454  int head_dim,
455  int aligned_head_dim,
456  int kv_stride_tokens,
457  float eps,
458  float *scratch);
459 
460 /**
461  * @brief Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
462  *
463  * Supports Q5_0 or Q8_0 weights with Q8_0 activations.
464  */
466  const float *x,
467  const float *gamma,
468  const void *Wq, const float *Bq, CKDataType wq_dt,
469  const void *Wk, const float *Bk, CKDataType wk_dt,
470  const void *Wv, const float *Bv, CKDataType wv_dt,
471  float *Q,
472  float *K,
473  float *V,
474  int seq_len,
475  int embed_dim,
476  int aligned_embed_dim,
477  int num_heads,
478  int num_kv_heads,
479  int head_dim,
480  int aligned_head_dim,
481  int kv_stride_tokens,
482  float eps,
483  void *scratch);
484 
485 /** @brief Unfused version for benchmarking comparison */
487  const float *x,
488  const float *gamma,
489  const float *Wq,
490  const float *Wk,
491  const float *Wv,
492  float *x_norm, /* [seq_len × hidden] intermediate buffer */
493  float *Q,
494  float *K,
495  float *V,
496  int seq_len,
497  int hidden,
498  int q_dim,
499  int kv_dim,
500  float eps);
501 
/** @brief Get scratch buffer size (bytes) for fused_rmsnorm_qkv_prefill. */
size_t fused_rmsnorm_qkv_scratch_size(int hidden);

/** @brief Get scratch buffer size (bytes) for
 *  fused_rmsnorm_qkv_prefill_head_major_quant. */
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim);
507 
508 /**
509  * @brief Fused MLP (Gate + Up + SwiGLU + Down) for prefill
510  *
511  * Tiles along token dimension to keep gate/up/hidden in L3 cache.
512  *
513  * @param scratch Temporary buffer from fused_mlp_swiglu_scratch_size()
514  */
516  const float *x, /* [seq_len × hidden] input */
517  const float *W_gate, /* [intermediate × hidden] (transposed) */
518  const float *W_up, /* [intermediate × hidden] (transposed) */
519  const float *W_down, /* [hidden × intermediate] (transposed) */
520  float *output, /* [seq_len × hidden] output */
521  int seq_len,
522  int hidden,
523  int intermediate,
524  float *scratch);
525 
526 /**
527  * @brief Fused MLP (Gate + Up + SwiGLU + Down) for prefill with biases
528  */
530  const float *x,
531  const float *W_gate,
532  const float *W_up,
533  const float *W_down,
534  const float *B_gate,
535  const float *B_up,
536  const float *B_down,
537  float *output,
538  int seq_len,
539  int hidden,
540  int intermediate,
541  float *scratch);
542 
/** @brief Get scratch buffer size (bytes) for fused_mlp_swiglu_prefill. */
size_t fused_mlp_swiglu_scratch_size(int intermediate);
545 
546 /**
547  * @brief Quantized fused MLP for prefill (W1=gate+up, W2=down)
548  *
549  * W1 uses Q8_0 activations (Q5_0/Q8_0 weights), W2 uses Q8_K activations
550  * (Q4_K/Q6_K weights).
551  */
553  const float *x,
554  const void *W1,
555  const float *B1,
556  CKDataType w1_dt,
557  const void *W2,
558  const float *B2,
559  CKDataType w2_dt,
560  float *output,
561  int seq_len,
562  int embed_dim,
563  int aligned_embed_dim,
564  int intermediate_dim,
565  int aligned_intermediate_dim,
566  void *scratch);
567 
/** @brief Get scratch buffer size (bytes) for fused_mlp_swiglu_prefill_w1w2_quant. */
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim,
                                                        int aligned_intermediate_dim);
571 
/* High-performance GEMM microkernel with 8x8 register blocking.
 * Inspired by oneDNN/BLIS - keeps all 64 accumulator values in registers.
 * C[M,N] = A[M,K] @ B[K,N]  or  C[M,N] = A[M,K] @ B[N,K].T
 * B_transposed: 0 = B is [K,N], 1 = B is [N,K] (transposed, common in NN weights). */
void gemm_microkernel(const float *A,
                      const float *B,
                      float *C,
                      int M, int N, int K,
                      int B_transposed);

/* Cache-blocked GEMM using 8x8 microkernels (B not transposed). */
void gemm_microkernel_blocked(const float *A,
                              const float *B,
                              float *C,
                              int M, int N, int K);

/* Cache-blocked GEMM for B transposed (common in NN). */
void gemm_microkernel_blocked_bt(const float *A,
                                 const float *B,
                                 float *C,
                                 int M, int N, int K);

/* Optimized GEMM with matrix packing (best for large matrices).
 * Packs A and B into contiguous layouts for optimal cache access. */
void gemm_microkernel_packed(const float *A,
                             const float *B,
                             float *C,
                             int M, int N, int K);
600 
/* LayerNorm forward kernels, copied from C-Transformer.
 * mean_cache/rstd_cache store per-token statistics for the backward pass. */
void layernorm_naive_serial(const float *input,
                            const float *gamma,
                            const float *beta,
                            float *output,
                            float *mean_cache,
                            float *rstd_cache,
                            int tokens, int d_model, int aligned_embed_dim,
                            float eps);

/* Processes a slice of tokens; all slice pointers are pre-offset by the caller. */
void layernorm_forward_rolled_slice(const float *__restrict input_slice_base,
                                    const float *__restrict gamma,
                                    const float *__restrict beta,
                                    float *__restrict output_slice_base,
                                    float *__restrict mean_cache_slice,
                                    float *__restrict rstd_cache_slice,
                                    int num_tokens_in_slice,
                                    int d_model,
                                    int aligned_embed_dim,
                                    float eps);

/* BF16 LayerNorm forward (rolled) - caller provides fp32 scratch buffers. */
void layernorm_forward_rolled_slice_bf16(const uint16_t *__restrict input_slice_base,
                                         const float *__restrict gamma,
                                         const float *__restrict beta,
                                         uint16_t *__restrict output_slice_base,
                                         float *__restrict mean_cache_slice,
                                         float *__restrict rstd_cache_slice,
                                         int num_tokens_in_slice,
                                         int d_model,
                                         int aligned_embed_dim,
                                         float eps,
                                         float *scratch_input,   /* [num_tokens * aligned_embed_dim] */
                                         float *scratch_output); /* [num_tokens * aligned_embed_dim] */

void layernorm_forward_unrolled_slice(const float *__restrict input_slice_base,
                                      const float *__restrict gamma,
                                      const float *__restrict beta,
                                      float *__restrict output_slice_base,
                                      float *__restrict mean_cache_slice,
                                      float *__restrict rstd_cache_slice,
                                      int num_tokens_in_slice,
                                      int d_model,
                                      float eps);

/* BF16 LayerNorm forward (unrolled) - caller provides fp32 scratch buffers. */
void layernorm_forward_unrolled_slice_bf16(const uint16_t *__restrict input_slice_base,
                                           const float *__restrict gamma,
                                           const float *__restrict beta,
                                           uint16_t *__restrict output_slice_base,
                                           float *__restrict mean_cache_slice,
                                           float *__restrict rstd_cache_slice,
                                           int num_tokens_in_slice,
                                           int d_model,
                                           float eps,
                                           float *scratch_input,   /* [num_tokens * d_model] */
                                           float *scratch_output); /* [num_tokens * d_model] */

void layernorm_naive_serial_matched_precision(const float *input,
                                              const float *gamma,
                                              const float *beta,
                                              float *output,
                                              float *mean_cache,
                                              float *rstd_cache,
                                              int tokens, int d_model, float eps);

void layernorm_backward_kernel(const float *d_output,
                               const float *input,
                               const float *gamma,
                               const float *mean,
                               const float *rstd,
                               float *d_input,
                               float *d_gamma,
                               float *d_beta,
                               int tokens, int d_model, int aligned_embed_dim);

/* BF16 LayerNorm backward - caller provides fp32 scratch buffers. */
void layernorm_backward_kernel_bf16(const uint16_t *d_output,
                                    const uint16_t *input,
                                    const float *gamma,
                                    const float *mean,
                                    const float *rstd,
                                    uint16_t *d_input,
                                    float *d_gamma,
                                    float *d_beta,
                                    int tokens, int d_model, int aligned_embed_dim,
                                    float *scratch_d_output, /* [tokens * aligned_embed_dim] */
                                    float *scratch_input,    /* [tokens * aligned_embed_dim] */
                                    float *scratch_d_input); /* [tokens * aligned_embed_dim] */
690 
/* RMSNorm forward/backward kernels.
 * rstd_cache stores the per-token reciprocal RMS for the backward pass. */
void rmsnorm_forward(const float *input,
                     const float *gamma,
                     float *output,
                     float *rstd_cache,
                     int tokens,
                     int d_model,
                     int aligned_embed_dim,
                     float eps);

void rmsnorm_backward(const float *d_output,
                      const float *input,
                      const float *gamma,
                      const float *rstd_cache,
                      float *d_input,
                      float *d_gamma,
                      int tokens,
                      int d_model,
                      int aligned_embed_dim);

void rmsnorm_forward_bf16(const uint16_t *input,
                          const float *gamma,
                          uint16_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps);

void rmsnorm_backward_bf16(const uint16_t *d_output,
                           const uint16_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           uint16_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim);

/* INT8 RMSNorm forward - caller provides fp32 scratch buffers. */
void rmsnorm_forward_int8(const int8_t *input,
                          const float *gamma,
                          int8_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps,
                          float *scratch_input,   /* [tokens * aligned_embed_dim] */
                          float *scratch_output); /* [tokens * aligned_embed_dim] */

/* INT8 RMSNorm backward - caller provides fp32 scratch buffers. */
void rmsnorm_backward_int8(const int8_t *d_output,
                           const int8_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           int8_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim,
                           float *scratch_d_output, /* [tokens * aligned_embed_dim] */
                           float *scratch_input,    /* [tokens * aligned_embed_dim] */
                           float *scratch_d_input); /* [tokens * aligned_embed_dim] */

/* INT4 RMSNorm forward - caller provides fp32 scratch buffers. */
void rmsnorm_forward_int4(const uint8_t *input,
                          const float *gamma,
                          uint8_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps,
                          float *scratch_input,   /* [tokens * aligned_embed_dim] */
                          float *scratch_output); /* [tokens * aligned_embed_dim] */

/* INT4 RMSNorm backward - caller provides fp32 scratch buffers. */
void rmsnorm_backward_int4(const uint8_t *d_output,
                           const uint8_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           uint8_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim,
                           float *scratch_d_output, /* [tokens * aligned_embed_dim] */
                           float *scratch_input,    /* [tokens * aligned_embed_dim] */
                           float *scratch_d_input); /* [tokens * aligned_embed_dim] */
781 
/* GELU forward kernel (fast approximation), copied from C-Transformer. */
void gelu_fast_inplace(float *data, size_t n);

/* Scalar-only exact GELU forward using standard library tanhf.
 * Slower but provides maximum accuracy. Used by the BF16 wrapper. */
void gelu_exact_inplace(float *data, size_t n);

/* GELU backward using tanh-based derivative (vectorized, uses fast tanh approx). */
void gelu_backward_exact(const float *input,
                         const float *d_output,
                         float *d_input,
                         size_t n);

/* Scalar-only exact GELU backward using standard library tanhf.
 * Slower but provides maximum accuracy. Used by the BF16 wrapper. */
void gelu_backward_scalar(const float *input,
                          const float *d_output,
                          float *d_input,
                          size_t n);

void gelu_backward_fast(const float *input,
                        const float *d_output,
                        float *d_input,
                        size_t n);

/* BF16 variants relying on the same floating-point logic.
 * BF16 GELU - caller provides a scratch buffer of [n] floats. */
void gelu_fast_inplace_bf16(uint16_t *data, size_t n, float *scratch);
void gelu_backward_exact_bf16(const uint16_t *input,
                              const uint16_t *d_output,
                              uint16_t *d_input,
                              size_t n,
                              float *scratch_input,
                              float *scratch_d_output,
                              float *scratch_d_input);
void gelu_backward_fast_bf16(const uint16_t *input,
                             const uint16_t *d_output,
                             uint16_t *d_input,
                             size_t n,
                             float *scratch_input,
                             float *scratch_d_output,
                             float *scratch_d_input);

/* GeGLU: out = GELU(a) * b where x = [a, b] along the last dimension.
 * Input shape: [tokens, 2 * dim], Output shape: [tokens, dim]. */
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim);
void geglu_forward_bf16(const uint16_t *x, uint16_t *out, int tokens, int dim, float *scratch);
void geglu_backward_fp32(const float *x,
                         const float *d_out,
                         float *d_x,
                         int tokens,
                         int dim);
834 
/* ReLU kernels. */
void relu_forward(const float *input, float *output, size_t n);
void relu_forward_inplace(float *data, size_t n);
void relu_backward(const float *input,
                   const float *d_output,
                   float *d_input,
                   size_t n);

/* BF16 ReLU variants. */
void relu_forward_bf16(const uint16_t *input, uint16_t *output, size_t n);
void relu_forward_inplace_bf16(uint16_t *data, size_t n);
void relu_backward_bf16(const uint16_t *input,
                        const uint16_t *d_output,
                        uint16_t *d_input,
                        size_t n);
849 
/* Causal softmax kernel on head-major attention scores, copied from C-Transformer.
 * scores layout: [head][query_token][key_token], row stride aligned_context_window. */
void causal_softmax_head_major(float *scores,
                               int num_heads,
                               int num_tokens,
                               int aligned_context_window);

/* Scalar-only exact causal softmax using standard library expf.
 * Slower but provides maximum accuracy. Used by the BF16 attention wrapper. */
void causal_softmax_head_major_exact(float *scores,
                                     int num_heads,
                                     int num_tokens,
                                     int aligned_context_window);

void backward_causal_softmax_head_major(float *d_scores,
                                        const float *weights,
                                        int num_heads,
                                        int num_tokens,
                                        int aligned_context_window);

/* BF16 causal softmax - caller provides a fp32 scratch buffer. */
void causal_softmax_head_major_bf16(uint16_t *scores,
                                    int num_heads,
                                    int num_tokens,
                                    int aligned_context_window,
                                    float *scratch); /* [num_heads * aligned_context_window * aligned_context_window] */

/* BF16 backward causal softmax - caller provides fp32 scratch buffers. */
void backward_causal_softmax_head_major_bf16(uint16_t *d_scores,
                                             const uint16_t *weights,
                                             int num_heads,
                                             int num_tokens,
                                             int aligned_context_window,
                                             float *scratch_d_scores, /* [num_heads * aligned_context_window * aligned_context_window] */
                                             float *scratch_weights); /* [num_heads * aligned_context_window * aligned_context_window] */
884 
/* Scaled dot-product attention (causal) in head-major layout.
 * Q/K/V layout: [head][token][head_dim] with stride aligned_head_dim.
 * scores: [head][query_token][key_token] with stride aligned_context_window.
 * output: same layout as Q/V. */
void attention_forward_causal_head_major(const float *q,
                                         const float *k,
                                         const float *v,
                                         float *scores,
                                         float *output,
                                         int num_heads,
                                         int num_tokens,
                                         int head_dim,
                                         int aligned_head_dim,
                                         int aligned_context_window);
899 
900 // Exact version using standard library expf (slower but accurate).
902  const float *k,
903  const float *v,
904  float *scores,
905  float *output,
906  int num_heads,
907  int num_tokens,
908  int head_dim,
909  int aligned_head_dim,
910  int aligned_context_window);
911 
/* GQA-aware attention: Q has num_heads, K/V have num_kv_heads
 * (each KV head is shared by num_heads / num_kv_heads query heads). */
void attention_forward_causal_head_major_gqa(const float *q,
                                             const float *k,
                                             const float *v,
                                             float *scores,
                                             float *output,
                                             int num_heads,
                                             int num_kv_heads,
                                             int num_tokens,
                                             int head_dim,
                                             int aligned_head_dim,
                                             int aligned_context_window);
924 
925 // Exact GQA version using standard library expf (slower but accurate).
927  const float *k,
928  const float *v,
929  float *scores,
930  float *output,
931  int num_heads,
932  int num_kv_heads,
933  int num_tokens,
934  int head_dim,
935  int aligned_head_dim,
936  int aligned_context_window);
937 
/* BF16 attention forward - caller provides fp32 scratch buffers (no internal malloc). */
void attention_forward_causal_head_major_gqa_bf16(const uint16_t *q,
                                                  const uint16_t *k,
                                                  const uint16_t *v,
                                                  float *scores,
                                                  float *output,
                                                  int num_heads,
                                                  int num_kv_heads,
                                                  int num_tokens,
                                                  int head_dim,
                                                  int aligned_head_dim,
                                                  int aligned_context_window,
                                                  float *scratch_q,  /* [num_heads * num_tokens * aligned_head_dim] */
                                                  float *scratch_k,  /* [num_kv_heads * num_tokens * aligned_head_dim] */
                                                  float *scratch_v); /* [num_kv_heads * num_tokens * aligned_head_dim] */
953 
// Flash-style causal attention forward (no score/weight matrix materialization).
// Head-major layout:
//   Q:   [num_heads,    num_tokens, aligned_head_dim]
//   K/V: [num_kv_heads, num_tokens, aligned_head_dim]
//   out: [num_heads,    num_tokens, aligned_head_dim]
// NOTE(review): this declaration's name line was lost in the doc extraction;
// reconstructed from the _strided variant below — confirm against the .c file.
void attention_forward_causal_head_major_gqa_flash(const float *q,
                                                   const float *k,
                                                   const float *v,
                                                   float *output,
                                                   int num_heads,
                                                   int num_kv_heads,
                                                   int num_tokens,
                                                   int head_dim,
                                                   int aligned_head_dim);

// Strided variant: kv_stride_tokens is the per-head token stride of K/V
// (e.g. cache_capacity when reading straight out of a KV cache).
void attention_forward_causal_head_major_gqa_flash_strided(const float *q,
                                                           const float *k,
                                                           const float *v,
                                                           float *output,
                                                           int num_heads,
                                                           int num_kv_heads,
                                                           int num_tokens,
                                                           int head_dim,
                                                           int aligned_head_dim,
                                                           int kv_stride_tokens);
979 
// Decode attention for a single token using a KV cache (flash-style).
//   q_token:         [num_heads, aligned_head_dim]
//   k_cache/v_cache: [num_kv_heads, cache_capacity, aligned_head_dim]
//   out_token:       [num_heads, aligned_head_dim]
// kv_tokens is the number of valid cached tokens (<= cache_capacity).
void attention_forward_decode_head_major_gqa_flash(const float *q_token,
                                                   const float *k_cache,
                                                   const float *v_cache,
                                                   float *out_token,
                                                   int num_heads,
                                                   int num_kv_heads,
                                                   int kv_tokens,
                                                   int cache_capacity,
                                                   int head_dim,
                                                   int aligned_head_dim);

// Decode attention for a single token using a KV cache (REGULAR - NOT flash).
// Same layout as the flash variant above.
// WARNING: This is O(n) complexity, not true flash attention!
void attention_forward_decode_head_major_gqa_regular(const float *q_token,
                                                     const float *k_cache,
                                                     const float *v_cache,
                                                     float *out_token,
                                                     int num_heads,
                                                     int num_kv_heads,
                                                     int kv_tokens,
                                                     int cache_capacity,
                                                     int head_dim,
                                                     int aligned_head_dim);
1010 
// Sliding-window attention forward (prefill, flash-style).
// Each token attends to at most the last `sliding_window` tokens.
// sliding_window: window size (0 or negative = no limit, like regular causal).
void attention_forward_causal_head_major_gqa_flash_strided_sliding(
    const float *q,
    const float *k,
    const float *v,
    float *output,
    int num_heads,
    int num_kv_heads,
    int num_tokens,
    int head_dim,
    int aligned_head_dim,
    int kv_stride_tokens,
    int sliding_window);

// Sliding-window attention forward (decode, flash-style).
// Single query token attends to the last `sliding_window` tokens in KV cache.
void attention_forward_decode_head_major_gqa_flash_sliding(
    const float *q_token,
    const float *k_cache,
    const float *v_cache,
    float *out_token,
    int num_heads,
    int num_kv_heads,
    int kv_tokens,
    int cache_capacity,
    int head_dim,
    int aligned_head_dim,
    int sliding_window);
1041 
// TRUE Flash Attention (O(1) extra memory for decode) - Tri Dao's algorithm.
// Token-major layout (note: differs from the head-major kernels above):
//   out: [T_q, H, D_h]
//   q:   [T_q, H, D_h]
//   k:   [T_k, H, D_h]
//   v:   [T_k, H, D_h]
// T_q:   query tokens (1 for decode)
// T_k:   context length
// H:     number of heads
// D_h:   head dimension
// scale: 1/sqrt(D_h)
void attention_flash_decode(float *out,
                            const float *q,
                            const float *k,
                            const float *v,
                            int T_q,
                            int T_k,
                            int H,
                            int D_h,
                            float scale);

// Diagnostics for flash attention tuning (used by unit tests).
int ck_flash_attn_choose_tile_k(int D_h);
int ck_flash_attn_fast_exp_kind(void);

// Orchestration wrapper: adapts the head-major KV-cache layout used by the
// engine onto attention_flash_decode.
void ck_attention_flash_decode_wrapper(const float *q_token,
                                       const float *k_cache,
                                       const float *v_cache,
                                       float *out_token,
                                       int num_heads,
                                       int num_kv_heads,
                                       int kv_tokens,
                                       int cache_capacity,
                                       int head_dim,
                                       int aligned_head_dim);
1077 
// KV cache helper: write one token's K/V vectors (all KV heads) into the
// head-major caches at position token_index.
void kv_cache_write_head_major(const float *__restrict k_token,
                               const float *__restrict v_token,
                               float *__restrict k_cache,
                               float *__restrict v_cache,
                               int num_kv_heads,
                               int token_index,
                               int cache_capacity,
                               int head_dim,
                               int aligned_head_dim);

// Store K/V for one position into a per-layer KV cache.
void kv_cache_store(float *__restrict kv_cache_k,
                    float *__restrict kv_cache_v,
                    const float *__restrict k,
                    const float *__restrict v,
                    int layer,
                    int pos,
                    int num_kv_heads,
                    int head_dim,
                    int max_seq_len);

// Repack a head-major tensor from a packed `[head, tokens, aligned_head_dim]`
// layout into a KV-cache-compatible layout `[head, cache_capacity, aligned_head_dim]`
// in-place. This is used after prefill when forward kernels write head slices
// back-to-back using `tokens` as the head stride, but decode expects a fixed
// `cache_capacity` stride.
void kv_cache_repack_head_major_inplace(float *buf,
                                        int num_heads,
                                        int tokens,
                                        int cache_capacity,
                                        int aligned_head_dim);
1109 
// MLP forward kernel (FC1 -> GELU -> FC2), generic token-parallel version.
// fc1_output is caller-provided scratch for the hidden activation.
void mlp_token_parallel(const float *input,
                        const float *W_fc1,
                        const float *b_fc1,
                        const float *W_fc2,
                        const float *b_fc2,
                        float *fc1_output,
                        float *output,
                        int T,
                        int aligned_dim,
                        int num_threads);

// Exact version using scalar GELU with standard library tanhf.
// Slower but provides maximum accuracy. Used for correctness testing.
void mlp_token_parallel_exact(const float *input,
                              const float *W_fc1,
                              const float *b_fc1,
                              const float *W_fc2,
                              const float *b_fc2,
                              float *fc1_output,
                              float *output,
                              int T,
                              int aligned_dim,
                              int num_threads);

/* BF16 MLP forward - caller provides scratch buffers. */
void mlp_token_parallel_bf16(const uint16_t *input,
                             const uint16_t *W_fc1,
                             const uint16_t *b_fc1,
                             const uint16_t *W_fc2,
                             const uint16_t *b_fc2,
                             float *fc1_output,
                             float *output,
                             int T,
                             int aligned_dim,
                             int num_threads,
                             float *scratch_bias1_f,        /* [4*D] */
                             float *scratch_bias2_f,        /* [D] */
                             uint16_t *scratch_fc1_bf16);   /* [T * 4*D] */

/* BF16 MLP forward with FP32 activations - caller provides scratch buffers. */
void mlp_token_parallel_bf16_fp32act(const uint16_t *input,
                                     const uint16_t *W_fc1,
                                     const uint16_t *b_fc1,
                                     const uint16_t *W_fc2,
                                     const uint16_t *b_fc2,
                                     float *fc1_output,
                                     float *output,
                                     int T,
                                     int aligned_dim,
                                     int num_threads,
                                     float *scratch_input_f,        /* [T * D] */
                                     float *scratch_bias1_f,        /* [4*D] */
                                     float *scratch_bias2_f,        /* [D] */
                                     uint16_t *scratch_fc1_bf16);   /* [T * 4*D] */
1165 
// MLP FC1/FC2 backward kernels (generic), adapted from C-Transformer.
// Each computes the input gradient plus weight/bias gradients for its layer.
void fc2_backward_kernel(const float *d_output,
                         const float *fc2_input,
                         const float *W_fc2,
                         float *d_input,
                         float *d_W_fc2,
                         float *d_b_fc2,
                         int T,
                         int aligned_in,
                         int aligned_out,
                         int num_threads);

void fc1_backward_kernel(const float *d_output,
                         const float *fc1_input,
                         const float *W_fc1,
                         float *d_input,
                         float *d_W_fc1,
                         float *d_b_fc1,
                         int T,
                         int aligned_in,
                         int aligned_out,
                         int num_threads);
1188 
// Sigmoid kernels (scalar + vector forms).
float sigmoid_scalar(float x);

void sigmoid_forward(const float *input,
                     float *output,
                     size_t n);

void sigmoid_backward(const float *input,
                      const float *d_output,
                      float *d_input,
                      size_t n);

/* BF16 sigmoid - caller provides scratch buffers, [n] floats each. */
void sigmoid_forward_bf16(const uint16_t *input,
                          uint16_t *output,
                          size_t n,
                          float *scratch_input,
                          float *scratch_output);

void sigmoid_backward_bf16(const uint16_t *input,
                           const uint16_t *d_output,
                           uint16_t *d_input,
                           size_t n,
                           float *scratch_input,
                           float *scratch_d_output,
                           float *scratch_d_input);
1215 
// SwiGLU activation kernels (forward + backward).
// Input layout per token: [gate[0..D-1], value[0..D-1]], size 2*D.
// Output: [D].
void swiglu_forward(const float *input,
                    float *output,
                    int tokens,
                    int dim);

void swiglu_backward(const float *input,
                     const float *d_output,
                     float *d_input,
                     int tokens,
                     int dim);

// Exact versions using standard library expf (slower but accurate).
void swiglu_forward_exact(const float *input,
                          float *output,
                          int tokens,
                          int dim);

void swiglu_backward_exact(const float *input,
                           const float *d_output,
                           float *d_input,
                           int tokens,
                           int dim);

// BF16 variants (same layout, bf16 storage).
void swiglu_forward_bf16(const uint16_t *input,
                         uint16_t *output,
                         int tokens,
                         int dim);

void swiglu_backward_bf16(const uint16_t *input,
                          const uint16_t *d_output,
                          uint16_t *d_input,
                          int tokens,
                          int dim);
1252 
// =============================================================================
// Element-wise addition kernels (for residual connections)
// =============================================================================

// Forward: y = a + b
void add_forward_bf16(const uint16_t *a,
                      const uint16_t *b,
                      uint16_t *y,
                      size_t n);

// Forward with scale: y = a + alpha * b
void add_scaled_forward_bf16(const uint16_t *a,
                             const uint16_t *b,
                             uint16_t *y,
                             float alpha,
                             size_t n);

// In-place: a += b
void add_inplace_bf16(uint16_t *a,
                      const uint16_t *b,
                      size_t n);

// In-place scaled: a += alpha * b
void add_scaled_inplace_bf16(uint16_t *a,
                             const uint16_t *b,
                             float alpha,
                             size_t n);

// Backward: d_a = d_y, d_b = d_y (gradient passthrough)
void add_backward_bf16(const uint16_t *d_y,
                       uint16_t *d_a,
                       uint16_t *d_b,
                       size_t n);

// 2D version for [tokens, dim] shaped tensors (rows padded to aligned_dim).
void add_forward_2d_bf16(const uint16_t *a,
                         const uint16_t *b,
                         uint16_t *y,
                         int tokens,
                         int dim,
                         int aligned_dim);

// FP32 versions
void add_forward_f32(const float *a,
                     const float *b,
                     float *y,
                     size_t n);

void add_inplace_f32(float *a,
                     const float *b,
                     size_t n);
1304 
// =============================================================================
// AXPY kernels (for MoE expert accumulation)
// =============================================================================

// In-place AXPY: y += alpha * x
void axpy_f32(float *y,
              const float *x,
              float alpha,
              int n);

// Scaled copy: y = alpha * x
void scal_copy_f32(float *y,
                   const float *x,
                   float alpha,
                   int n);

// Weighted sum of k vectors of length n: y = sum_i(weights[i] * vectors[i])
void weighted_sum_f32(float *y,
                      const float **vectors,
                      const float *weights,
                      int k,
                      int n);

// Zero-then-accumulate: y = 0; y += alpha * x
void axpy_zero_f32(float *y,
                   const float *x,
                   float alpha,
                   int n);

// Batched 2D AXPY: Y[t,:] += alpha * X[t,:] (row strides given in elements).
void axpy_2d_f32(float *Y,
                 const float *X,
                 float alpha,
                 int num_tokens,
                 int dim,
                 int y_stride,
                 int x_stride);

// MoE expert accumulation: output += routing_weight * expert_output
void moe_accumulate_expert_f32(float *output,
                               const float *expert_output,
                               float routing_weight,
                               int hidden_dim);
1348 
// =============================================================================
// Top-K selection kernels (for MoE router dispatch)
// =============================================================================

// Find top-K indices and values from scores[0..n-1].
void topk_f32(const float *scores,
              int n,
              int k,
              int *indices,
              float *values);

// Top-K with softmax-normalized weights (weights sum to 1 over the K picks).
void topk_softmax_f32(const float *scores,
                      int n,
                      int k,
                      int *indices,
                      float *weights);

// Batched top-K for multiple tokens; scores is [num_tokens, n_experts].
void topk_batched_f32(const float *scores,
                      int num_tokens,
                      int n_experts,
                      int k,
                      int *indices,
                      float *weights);

// Argmax (top-1); returns the index of the largest score.
int argmax_f32(const float *scores, int n);
1377 
// Attention backward (GQA-aware): computes d_q, d_k, d_v from the saved
// forward attention weights. d_scores is caller-provided scratch.
// NOTE(review): both declaration names below were lost in the doc extraction;
// reconstructed from the bf16 variant's indexed signature — confirm in the .c.
void attention_backward_causal_head_major_gqa(
    const float *d_output,
    const float *q,
    const float *k,
    const float *v,
    const float *attn_weights,
    float *d_q,
    float *d_k,
    float *d_v,
    float *d_scores,
    int num_heads,
    int num_kv_heads,
    int num_tokens,
    int head_dim,
    int aligned_head_dim,
    int aligned_context_window);

// Attention backward (non-GQA): num_kv_heads == num_heads.
void attention_backward_causal_head_major(
    const float *d_output,
    const float *q,
    const float *k,
    const float *v,
    const float *attn_weights,
    float *d_q,
    float *d_k,
    float *d_v,
    float *d_scores,
    int num_heads,
    int num_tokens,
    int head_dim,
    int aligned_head_dim,
    int aligned_context_window);
1412 
/* BF16 attention backward - caller provides scratch buffers (no internal malloc).
 * Inputs are bf16; gradients are accumulated in FP32. */
void attention_backward_causal_head_major_gqa_bf16(
    const uint16_t *d_output,
    float *d_x,
    const uint16_t *q,
    const uint16_t *k,
    const uint16_t *v,
    const float *attn_weights,
    float *d_q,
    float *d_k,
    float *d_v,
    float *d_scores,
    int num_heads,
    int num_kv_heads,
    int num_tokens,
    int head_dim,
    int aligned_head_dim,
    int aligned_context_window,
    float *scratch_d_output,   /* [num_heads * num_tokens * aligned_head_dim] */
    float *scratch_q,          /* [num_heads * num_tokens * aligned_head_dim] */
    float *scratch_k,          /* [num_kv_heads * num_tokens * aligned_head_dim] */
    float *scratch_v);         /* [num_kv_heads * num_tokens * aligned_head_dim] */
1435 
// RoPE (Rotary Position Embedding) kernels.
// Precompute cos/sin cache: [max_seq_len, head_dim/2].
void rope_precompute_cache(float *cos_cache,
                           float *sin_cache,
                           int max_seq_len,
                           int head_dim,
                           float base);

// Apply RoPE forward in-place: x is [num_heads, num_tokens, aligned_head_dim].
// pos_offset shifts the position index (used for decode with a KV cache).
void rope_forward(float *x,
                  const float *cos_cache,
                  const float *sin_cache,
                  int num_heads,
                  int num_tokens,
                  int head_dim,
                  int aligned_head_dim,
                  int pos_offset);

// RoPE backward: applies the inverse rotation to the output gradient.
void rope_backward(const float *d_out,
                   float *d_x,
                   const float *cos_cache,
                   const float *sin_cache,
                   int num_heads,
                   int num_tokens,
                   int head_dim,
                   int aligned_head_dim,
                   int pos_offset);

/* BF16 RoPE forward - caller provides scratch buffer. */
void rope_forward_bf16(uint16_t *x,
                       const float *cos_cache,
                       const float *sin_cache,
                       int num_heads,
                       int num_tokens,
                       int head_dim,
                       int aligned_head_dim,
                       int pos_offset,
                       float *scratch);   /* [num_heads * num_tokens * aligned_head_dim] */

/* BF16 RoPE backward - caller provides scratch buffers. */
void rope_backward_bf16(const uint16_t *d_out,
                        uint16_t *d_x,
                        const float *cos_cache,
                        const float *sin_cache,
                        int num_heads,
                        int num_tokens,
                        int head_dim,
                        int aligned_head_dim,
                        int pos_offset,
                        float *scratch_d_out,   /* [num_heads * num_tokens * aligned_head_dim] */
                        float *scratch_d_x);    /* [num_heads * num_tokens * aligned_head_dim] */

// RoPE backward in-place (rotates d_x directly).
void rope_backward_inplace(float *d_x,
                           const float *cos_cache,
                           const float *sin_cache,
                           int num_heads,
                           int num_tokens,
                           int head_dim,
                           int aligned_head_dim,
                           int pos_offset);

// Strided forward: head_stride_tokens is the per-head token stride of x.
void rope_forward_strided(float *x,
                          const float *cos_cache,
                          const float *sin_cache,
                          int num_heads,
                          int num_tokens,
                          int head_dim,
                          int aligned_head_dim,
                          int pos_offset,
                          int head_stride_tokens);
1508 
// Combined RoPE for Q and K (GQA-aware: K uses num_kv_heads).
void rope_forward_qk(float *q,
                     float *k,
                     const float *cos_cache,
                     const float *sin_cache,
                     int num_heads,
                     int num_kv_heads,
                     int num_tokens,
                     int head_dim,
                     int aligned_head_dim,
                     int pos_offset);

// Strided variant: separate per-head token strides for Q and K.
void rope_forward_qk_strided(float *q,
                             float *k,
                             const float *cos_cache,
                             const float *sin_cache,
                             int num_heads,
                             int num_kv_heads,
                             int num_tokens,
                             int head_dim,
                             int aligned_head_dim,
                             int pos_offset,
                             int q_stride_tokens,
                             int k_stride_tokens);

// Combined RoPE backward for Q and K gradients.
void rope_backward_qk(const float *d_q_out,
                      const float *d_k_out,
                      float *d_q,
                      float *d_k,
                      const float *cos_cache,
                      const float *sin_cache,
                      int num_heads,
                      int num_kv_heads,
                      int num_tokens,
                      int head_dim,
                      int aligned_head_dim,
                      int pos_offset);

/* BF16 RoPE forward for Q and K - caller provides scratch buffers. */
void rope_forward_qk_bf16(uint16_t *q,
                          uint16_t *k,
                          const float *cos_cache,
                          const float *sin_cache,
                          int num_heads,
                          int num_kv_heads,
                          int num_tokens,
                          int head_dim,
                          int aligned_head_dim,
                          int pos_offset,
                          float *scratch_q,   /* [num_heads * num_tokens * aligned_head_dim] */
                          float *scratch_k);  /* [num_kv_heads * num_tokens * aligned_head_dim] */

/* BF16 RoPE backward for Q and K - caller provides scratch buffers. */
void rope_backward_qk_bf16(const uint16_t *d_q_out,
                           const uint16_t *d_k_out,
                           uint16_t *d_q,
                           uint16_t *d_k,
                           const float *cos_cache,
                           const float *sin_cache,
                           int num_heads,
                           int num_kv_heads,
                           int num_tokens,
                           int head_dim,
                           int aligned_head_dim,
                           int pos_offset,
                           float *scratch_dq_out,   /* [num_heads * num_tokens * aligned_head_dim] */
                           float *scratch_dq,       /* [num_heads * num_tokens * aligned_head_dim] */
                           float *scratch_dk_out,   /* [num_kv_heads * num_tokens * aligned_head_dim] */
                           float *scratch_dk);      /* [num_kv_heads * num_tokens * aligned_head_dim] */
1578 
// Token embedding lookup (optionally adds positional embeddings).
// token_embeddings: [vocab_size x aligned_embed_dim]
// pos_embeddings:   [context_window x aligned_embed_dim] or NULL if add_pos == 0.
// output:           [context_window x aligned_embed_dim]
void embedding_forward(const int32_t *token_ids,
                       int token_count,
                       int vocab_size,
                       const float *token_embeddings,
                       const float *pos_embeddings,
                       float *output,
                       int embed_dim,
                       int aligned_embed_dim,
                       int context_window,
                       int add_pos);

// Quantized-weight variants: token_embeddings points at the corresponding
// quant block format (see ckernel_quant.h); output is dequantized FP32.
void embedding_forward_q4_k(const int32_t *token_ids,
                            int token_count,
                            int vocab_size,
                            const void *token_embeddings,
                            const float *pos_embeddings,
                            float *output,
                            int embed_dim,
                            int aligned_embed_dim,
                            int context_window,
                            int add_pos);

void embedding_forward_q6_k(const int32_t *token_ids,
                            int token_count,
                            int vocab_size,
                            const void *token_embeddings,
                            const float *pos_embeddings,
                            float *output,
                            int embed_dim,
                            int aligned_embed_dim,
                            int context_window,
                            int add_pos);

void embedding_forward_q8_0(const int32_t *token_ids,
                            int token_count,
                            int vocab_size,
                            const void *token_embeddings,
                            const float *pos_embeddings,
                            float *output,
                            int embed_dim,
                            int aligned_embed_dim,
                            int context_window,
                            int add_pos);

// BF16 variant: embeddings and output are bf16.
void embedding_forward_bf16(const int32_t *token_ids,
                            int token_count,
                            int vocab_size,
                            const uint16_t *token_embeddings,
                            const uint16_t *pos_embeddings,
                            uint16_t *output,
                            int embed_dim,
                            int aligned_embed_dim,
                            int context_window,
                            int add_pos);

// Embedding backward: accumulates into d_token_embeddings and d_pos_embeddings.
// d_output:           [context_window x aligned_embed_dim]
// d_token_embeddings: [vocab_size x aligned_embed_dim]
// d_pos_embeddings:   [context_window x aligned_embed_dim] (optional)
void embedding_backward(const int32_t *token_ids,
                        int token_count,
                        const float *d_output,
                        float *d_token_embeddings,
                        float *d_pos_embeddings,
                        int vocab_size,
                        int embed_dim,
                        int aligned_embed_dim,
                        int context_window,
                        int add_pos);

void embedding_backward_bf16(const int32_t *token_ids,
                             int token_count,
                             const uint16_t *d_output,
                             uint16_t *d_token_embeddings,
                             uint16_t *d_pos_embeddings,
                             int vocab_size,
                             int embed_dim,
                             int aligned_embed_dim,
                             int context_window,
                             int add_pos);
1663 
// Softmax cross-entropy loss + gradient w.r.t. logits.
// logits: [tokens x vocab_size], targets: [tokens],
// d_logits: [tokens x vocab_size], loss_out: mean loss over tokens.
void softmax_cross_entropy_loss(const float *logits,
                                const int32_t *targets,
                                int tokens,
                                int vocab_size,
                                float *d_logits,
                                float *loss_out);

/* BF16 softmax cross-entropy loss - caller provides scratch buffers. */
void softmax_cross_entropy_loss_bf16(const uint16_t *logits,
                                     const int32_t *targets,
                                     int tokens,
                                     int vocab_size,
                                     uint16_t *d_logits,
                                     float *loss_out,
                                     float *scratch_logits,     /* [tokens * vocab_size] */
                                     float *scratch_d_logits);  /* [tokens * vocab_size] */

// Vision helpers (patchify/unpatchify) for C-channel HxW images, patch size P.
void im2patch(const float *image,
              float *patches,
              int C, int H, int W, int P);
void patch2im(const float *d_patches,
              float *d_image,
              int C, int H, int W, int P);

void im2patch_bf16(const uint16_t *image,
                   uint16_t *patches,
                   int C, int H, int W, int P);
void patch2im_bf16(const uint16_t *d_patches,
                   uint16_t *d_image,
                   int C, int H, int W, int P);
1697 
1698 #ifdef __cplusplus
1699 } // extern "C"
1700 #endif
1701 
1702 #endif // CKERNEL_ENGINE_H
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
void gemm_nt_q4_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void attention_forward_causal_head_major_gqa_bf16(const uint16_t *q, const uint16_t *k, const uint16_t *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_q, float *scratch_k, float *scratch_v)
void dequant_q4_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_0 row (multiple blocks)
void embedding_forward_q6_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void attention_forward_causal_head_major_gqa_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void axpy_f32(float *y, const float *x, float alpha, int n)
In-place AXPY: y += alpha * x.
Definition: axpy_kernels.c:54
void rmsnorm_forward_int8(const int8_t *input, const float *gamma, int8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
void gemm_q6_k(float *Y, const void *W, const float *X, int M, int N, int K)
void ck_gemm_nt_head_major_q8_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q8_0 weights)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
int ck_flash_attn_choose_tile_k(int D_h)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
Definition: rope_kernels.c:472
void moe_accumulate_expert_f32(float *output, const float *expert_output, float routing_weight, int hidden_dim)
Accumulate expert output: output += routing_weight * expert_output.
Definition: axpy_kernels.c:256
void swiglu_forward_exact(const float *input, float *output, int tokens, int dim)
void rmsnorm_backward_bf16(const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *rstd_cache, uint16_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
void gemm_naive_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:125
void gemv_fused_q8_0_bias_dispatch(float *y, const void *W, const float *x, const float *bias, int M, int K)
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void gemm_bias_silu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void backward_causal_softmax_head_major_bf16(uint16_t *d_scores, const uint16_t *weights, int num_heads, int num_tokens, int aligned_context_window, float *scratch_d_scores, float *scratch_weights)
void add_scaled_forward_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, float alpha, size_t n)
void gemm_swiglu_fused(const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)
void ck_set_num_threads(int num_threads)
void swiglu_backward(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void gelu_backward_exact(const float *input, const float *d_output, float *d_input, size_t n)
Definition: gelu_kernels.c:257
void ck_gemm_nt_head_major_q5_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (auto-dispatch)
void axpy_zero_f32(float *y, const float *x, float alpha, int n)
Zero output then accumulate: y = 0; y += alpha * x.
Definition: axpy_kernels.c:188
void fused_mlp_swiglu_prefill(const float *x, const float *W_gate, const float *W_up, const float *W_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
Fused MLP (Gate + Up + SwiGLU + Down) for prefill.
void gelu_exact_inplace(float *data, size_t n)
Definition: gelu_kernels.c:446
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void add_inplace_bf16(uint16_t *a, const uint16_t *b, size_t n)
void attention_backward_causal_head_major_gqa_bf16(const uint16_t *d_output, float *d_x, const uint16_t *q, const uint16_t *k, const uint16_t *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_d_output, float *scratch_q, float *scratch_k, float *scratch_v)
void gemm_nn_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:317
void fused_mlp_swiglu_decode_tiled(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:180
void kv_cache_repack_head_major_inplace(float *buf, int num_heads, int tokens, int cache_capacity, int aligned_head_dim)
void fused_mlp_swiglu_decode(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
void gemm_bias_relu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void layernorm_naive_serial_matched_precision(const float *input, const float *gamma, const float *beta, float *output, float *mean_cache, float *rstd_cache, int tokens, int d_model, float eps)
void rope_precompute_cache(float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
Definition: rope_kernels.c:52
void gemv_q6_k(float *y, const void *W, const float *x, int M, int K)
void topk_batched_f32(const float *scores, int num_tokens, int n_experts, int k, int *indices, float *weights)
Batched top-K selection for multiple tokens.
Definition: topk_kernels.c:191
void backward_causal_softmax_head_major(float *d_scores, const float *weights, int num_heads, int num_tokens, int aligned_context_window)
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.
void attention_forward_causal_head_major(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void gemv_q8_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q8_0 weights based on CPU features.
void attention_forward_causal_head_major_gqa(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void patch2im(const float *d_patches, float *d_image, int C, int H, int W, int P)
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
void swiglu_forward_bf16(const uint16_t *input, uint16_t *output, int tokens, int dim)
void rope_backward_bf16(const uint16_t *d_out, uint16_t *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_d_out, float *scratch_d_x)
void rope_backward_qk(const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:497
void relu_backward(const float *input, const float *d_output, float *d_input, size_t n)
Definition: relu_kernels.c:84
void mlp_token_parallel_bf16(const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
void attention_forward_causal_head_major_gqa_flash_strided_sliding(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
void geglu_backward_fp32(const float *x, const float *d_out, float *d_x, int tokens, int dim)
Definition: gelu_kernels.c:843
void embedding_forward_q4_k(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gemm_nt_q4_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
void dequant_q5_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_0 row (multiple blocks)
void add_forward_f32(const float *a, const float *b, float *y, size_t n)
void rmsnorm_forward_bf16(const uint16_t *input, const float *gamma, uint16_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q5_0 weights and Q8_0 input.
void add_inplace_f32(float *a, const float *b, size_t n)
void attention_forward_decode_head_major_gqa_flash_sliding(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
void relu_forward_inplace_bf16(uint16_t *data, size_t n)
void gemv_q5_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
void im2patch_bf16(const uint16_t *image, uint16_t *patches, int C, int H, int W, int P)
void attention_forward_causal_head_major_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void gemm_microkernel(const float *A, const float *B, float *C, int M, int N, int K, int B_transposed)
void mlp_token_parallel(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
Definition: mlp_kernels.c:41
void softmax_cross_entropy_loss_bf16(const uint16_t *logits, const int32_t *targets, int tokens, int vocab_size, uint16_t *d_logits, float *loss_out, float *scratch_logits, float *scratch_d_logits)
void rope_backward_qk_bf16(const uint16_t *d_q_out, const uint16_t *d_k_out, uint16_t *d_q, uint16_t *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_dq_out, float *scratch_dq, float *scratch_dk_out, float *scratch_dk)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void weighted_sum_f32(float *y, const float **vectors, const float *weights, int k, int n)
Weighted sum of k vectors: y = sum_i(weights[i] * vectors[i])
Definition: axpy_kernels.c:155
void rmsnorm_backward_int4(const uint8_t *d_output, const uint8_t *input, const float *gamma, const float *rstd_cache, uint8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
CKMathBackend ckernel_backend_native(void)
void embedding_forward_bf16(const int32_t *token_ids, int token_count, int vocab_size, const uint16_t *token_embeddings, const uint16_t *pos_embeddings, uint16_t *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gelu_backward_fast_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
void causal_softmax_head_major_exact(float *scores, int num_heads, int num_tokens, int aligned_context_window)
void gemm_q6_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.
void gemv_q4_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
void gemm_nt_q5_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void rope_backward_inplace(float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:345
void layernorm_naive_serial(const float *input, const float *gamma, const float *beta, float *output, float *mean_cache, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void fc1_backward_kernel(const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:167
void gemm_nt_q6_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void swiglu_backward_exact(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void swiglu_backward_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, int tokens, int dim)
void mlp_token_parallel_bf16_fp32act(const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_input_f, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
void gelu_backward_exact_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
GEMV: y = W @ x where W is Q6_K and x is Q8_K.
void embedding_backward_bf16(const int32_t *token_ids, int token_count, const uint16_t *d_output, uint16_t *d_token_embeddings, uint16_t *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K)
void embedding_backward(const int32_t *token_ids, int token_count, const float *d_output, float *d_token_embeddings, float *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
int ck_get_physical_cores(void)
void mlp_token_parallel_exact(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
Definition: mlp_kernels.c:76
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim)
Definition: gelu_kernels.c:623
void relu_forward_bf16(const uint16_t *input, uint16_t *output, size_t n)
void attention_forward_decode_head_major_gqa_flash(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
void gemm_tn_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:499
void gemm_nt_q8_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void relu_forward(const float *input, float *output, size_t n)
Definition: relu_kernels.c:26
void im2patch(const float *image, float *patches, int C, int H, int W, int P)
void fused_rmsnorm_qkv_prefill(const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill.
void quantize_batch_q8_k(const float *x, void *y, int num_rows, int k)
Batch quantize FP32 to Q8_K format (row-major output)
void vec_dot_q6_k_q8_k(int n, float *s, const void *vx, const void *vy)
Q6_K x Q8_K dot product (single row)
void fc2_backward_kernel(const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:118
void quantize_row_q8_k(const float *x, void *y, int k)
void axpy_2d_f32(float *Y, const float *X, float alpha, int num_tokens, int dim, int y_stride, int x_stride)
Batched AXPY for 2D tensors: Y[t,:] += alpha * X[t,:].
Definition: axpy_kernels.c:221
void ck_set_strict_parity(int enabled)
void rmsnorm_backward_int8(const int8_t *d_output, const int8_t *input, const float *gamma, const float *rstd_cache, int8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
void layernorm_backward_kernel(const float *d_output, const float *input, const float *gamma, const float *mean, const float *rstd, float *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim)
void sigmoid_forward_bf16(const uint16_t *input, uint16_t *output, size_t n, float *scratch_input, float *scratch_output)
void gemm_blocked_serial_bf16(const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
void sigmoid_backward_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
void gemm_nn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:339
void causal_softmax_head_major(float *scores, int num_heads, int num_tokens, int aligned_context_window)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
int argmax_f32(const float *scores, int n)
Find index of maximum value.
Definition: topk_kernels.c:226
void unfused_rmsnorm_qkv_prefill(const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *x_norm, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps)
Unfused version for benchmarking comparison.
void add_forward_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, size_t n)
void add_forward_2d_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, int tokens, int dim, int aligned_dim)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemm_avx512_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:149
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
void gelu_fast_inplace_bf16(uint16_t *data, size_t n, float *scratch)
void dequant_q8_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q8_0 row (multiple blocks)
void rope_forward_bf16(uint16_t *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch)
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
void layernorm_backward_kernel_bf16(const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *mean, const float *rstd, uint16_t *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
void gemv_q6_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Parallel reference GEMV for Q6_K × Q8_K.
float sigmoid_scalar(float x)
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
void ck_attention_flash_decode_wrapper(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
Wrapper to call TRUE flash attention from orchestration layer.
void sigmoid_backward(const float *input, const float *d_output, float *d_input, size_t n)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:448
size_t fused_rmsnorm_qkv_scratch_size(int hidden)
Get scratch buffer size for fused_rmsnorm_qkv_prefill.
void gemv_q8_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q8_0 weights and Q8_0 input.
void gemm_nt_q8_0_q8_0(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
void gemm_fine_grained_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:205
void relu_forward_inplace(float *data, size_t n)
Definition: relu_kernels.c:54
void gelu_fast_inplace(float *data, size_t n)
Definition: gelu_kernels.c:132
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q5_0_parallel(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
Parallel reference GEMV for Q5_0 × FP32.
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void geglu_forward_bf16(const uint16_t *x, uint16_t *out, int tokens, int dim, float *scratch)
Definition: gelu_kernels.c:813
void gemm_q4_k(float *Y, const void *W, const float *X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.
void gelu_backward_scalar(const float *input, const float *d_output, float *d_input, size_t n)
Definition: gelu_kernels.c:462
void rmsnorm_forward_int4(const uint8_t *input, const float *gamma, uint8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
void gemm_q4_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
void gemm_microkernel_blocked(const float *A, const float *B, float *C, int M, int N, int K)
void dequant_q5_1_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_1 row (multiple blocks)
void fused_mlp_swiglu_decode_v2(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
void dequant_q4_1_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_1 row (multiple blocks)
void gemv_fused_q5_0_bias_dispatch(float *y, const void *W, const float *x, const float *bias, int M, int K)
void fused_mlp_swiglu_prefill_bias(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *B_gate, const float *B_up, const float *B_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
Fused MLP (Gate + Up + SwiGLU + Down) for prefill with biases.
void layernorm_forward_rolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps)
void embedding_forward_q8_0(const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
int ck_flash_attn_fast_exp_kind(void)
void attention_backward_causal_head_major(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void embedding_forward(const int32_t *token_ids, int token_count, int vocab_size, const float *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
void layernorm_forward_unrolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps, float *scratch_input, float *scratch_output)
void gemm_bias_gelu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void fused_mlp_swiglu_prefill_w1w2_quant(const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch)
Quantized fused MLP for prefill (W1=gate+up, W2=down)
void rope_forward_qk_bf16(uint16_t *q, uint16_t *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_q, float *scratch_k)
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void topk_f32(const float *scores, int n, int k, int *indices, float *values)
Find top-K indices and values from a score vector.
Definition: topk_kernels.c:49
void gemm_microkernel_blocked_bt(const float *A, const float *B, float *C, int M, int N, int K)
void dequant_q6_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q6_K row (multiple blocks)
void gemv_q6_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q6_K × Q8_K.
void gemm_tn_blocked(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:581
void topk_softmax_f32(const float *scores, int n, int k, int *indices, float *weights)
Find top-K indices with softmax-normalized weights.
Definition: topk_kernels.c:134
void gemv_q5_0_parallel_simd(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q5_0 × FP32 with prefetching.
void kv_cache_store(float *__restrict kv_cache_k, float *__restrict kv_cache_v, const float *__restrict k, const float *__restrict v, int layer, int pos, int num_kv_heads, int head_dim, int max_seq_len)
void gelu_backward_fast(const float *input, const float *d_output, float *d_input, size_t n)
Definition: gelu_kernels.c:486
void softmax_cross_entropy_loss(const float *logits, const int32_t *targets, int tokens, int vocab_size, float *d_logits, float *loss_out)
Definition: loss_kernels.c:21
void causal_softmax_head_major_bf16(uint16_t *scores, int num_heads, int num_tokens, int aligned_context_window, float *scratch)
int ck_get_num_threads(void)
void patch2im_bf16(const uint16_t *d_patches, uint16_t *d_image, int C, int H, int W, int P)
int ck_strict_parity_enabled(void)
void add_backward_bf16(const uint16_t *d_y, uint16_t *d_a, uint16_t *d_b, size_t n)
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void rope_backward(const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:238
void dequant_q4_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_K row (multiple blocks)
void scal_copy_f32(float *y, const float *x, float alpha, int n)
Scaled copy: y = alpha * x.
Definition: axpy_kernels.c:105
void gemm_microkernel_packed(const float *A, const float *B, float *C, int M, int N, int K)
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:661
void rope_forward_strided(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)
Definition: rope_kernels.c:207
size_t fused_mlp_swiglu_scratch_size(int intermediate)
Get scratch buffer size for fused_mlp_swiglu_prefill.
void sigmoid_forward(const float *input, float *output, size_t n)
void layernorm_forward_rolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
void relu_backward_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n)
void quantize_batch_q8_0(const float *x, void *y, int num_rows, int k)
Batch quantize FP32 to Q8_0 format (row-major output)
void gemv_q4_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
void gemv_q4_k_q8_k_parallel_simd(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
void layernorm_forward_unrolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)
void gemm_nn_blocked(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:402
void add_scaled_inplace_bf16(uint16_t *a, const uint16_t *b, float alpha, size_t n)
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
void gemm_tn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:521
Quantization block structures for weight-only quantization.
Mega-Fused Attention Kernel.
#define C(color)
Definition: show_config.c:39
int vocab_size
Definition: true_bpe.h:185