← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_orchestration.c File Reference
#include "ckernel_orchestration.h"
#include "ckernel_engine.h"
#include "ckernel_dtype.h"
#include "ckernel_quant.h"
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

Go to the source code of this file.

Functions

static void ck_add_inplace (float *dst, const float *src, int tokens, int aligned_embed_dim)
 
void ck_attention_flash_decode_wrapper (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
 Wrapper to call TRUE flash attention from orchestration layer. More...
 
void ck_attention_project_head_major (const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
void ck_attention_project_head_major_backward (const float *d_out, const float *attn_out, const float *wo, float *d_attn_out, float *d_wo, float *d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static void ck_attention_project_head_major_q4_k (const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static void ck_attention_project_head_major_q4_k_q8_k (const float *attn_out, const void *wo, const float *bo, float *out, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static void ck_attention_project_head_major_quant (const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, CKDataType wo_dtype)
 
static void ck_attention_project_head_major_ref (const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
static void ck_debug_check_buffer (const char *stage, const float *buf, int size)
 
static void ck_debug_check_q4k_weights (const char *stage, const void *q4_buf, int num_blocks)
 
static void ck_debug_check_q8k (const char *stage, const void *q8_buf, int num_blocks)
 
void ck_gemm_nt_quant (const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
 
void ck_layer_backward_rmsnorm_swiglu (const CKLayerBackwardParams *p)
 
static int ck_layer_debug_enabled (void)
 
void ck_layer_forward_rmsnorm_swiglu (const CKLayerForwardParams *p)
 
void ck_layer_forward_rmsnorm_swiglu_decode (const CKLayerForwardParams *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_decode_fused (const CKLayerForwardParams *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_decode_q4_k (const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_decode_quant (const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_q4_k (const CKLayerForwardParamsQ4K *p)
 
void ck_layer_forward_rmsnorm_swiglu_quant (const CKLayerForwardParamsQ4K *p)
 
void ck_layer_forward_rmsnorm_swiglu_ref (const CKLayerForwardParams *p)
 
void ck_mlp_swiglu_forward (const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
 
void ck_mlp_swiglu_forward_fully_fused_token (const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
 
void ck_mlp_swiglu_forward_fused_token (const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *swiglu_row, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
 
static void ck_mlp_swiglu_forward_q4_k (const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
 
static void ck_mlp_swiglu_forward_q4_k_q8_k (const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int aligned_embed_dim, int aligned_intermediate_dim)
 
static void ck_mlp_swiglu_forward_q4_k_q8_k_prefill (const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
 
static void ck_mlp_swiglu_forward_quant (const float *input, const void *w1, const float *b1, CKDataType w1_dtype, const void *w2, const float *b2, CKDataType w2_dtype, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
 
static void ck_mlp_swiglu_forward_ref (const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
 
static int ck_q8k_activations_enabled (void)
 
void ck_qkv_project_head_major (const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
void ck_qkv_project_head_major_backward (const float *d_q, const float *d_k, const float *d_v, const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *d_input, float *d_wq, float *d_bq, float *d_wk, float *d_bk, float *d_wv, float *d_bv, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)
 
static void ck_qkv_project_head_major_q4_k (const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
static void ck_qkv_project_head_major_q4_k_q8_k (const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
static void ck_qkv_project_head_major_quant (const float *input, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
static void ck_qkv_project_head_major_ref (const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
static void ck_qkv_project_head_major_token_q4_k (const float *input_row, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
static void ck_qkv_project_head_major_token_q4_k_q8_k (const block_q8_K *input_q8, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
static void ck_qkv_project_head_major_token_quant (const float *input_row, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
void ck_residual_add_backward (const float *d_out, float *d_a, float *d_b, int tokens, int aligned_embed_dim)
 
void ck_residual_add_token_major (const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
 

Detailed Description


LEGACY CODE - NOT USED IN v6.6

This file contains v6.5 orchestration code that is NO LONGER USED. It is kept for reference and potential future use but is NOT compiled into the v6.6 engine.

v6.6 Architecture:

  • IR Lower 3 handles all orchestration via dataflow graph
  • Kernel dispatch via ckernel_codegen.c (for dynamically loaded kernels)
  • Memory planning via memory_planner_v6_6.py

Contents of this file (NOT used):

  • ck_attention_flash_decode_wrapper: Flash attention wrapper (use mega_fused_attention_prefill/avx instead)
  • ck_gemm_nt_quant: Dispatcher for Q4_K, Q4_0, Q4_1, Q5_0, Q5_1, Q6_K, Q8_0 (use kernel_maps/KERNEL_REGISTRY.json + codegen instead)

To remove completely:

  1. Delete this file
  2. Remove from Makefile SRCS list
  3. Remove ckernel_orchestration.h

Last used: v6.5

Deprecated: v6.6 (2026-02)

Definition in file ckernel_orchestration.c.

Function Documentation

◆ ck_add_inplace()

static void ck_add_inplace ( float *  dst,
const float *  src,
int  tokens,
int  aligned_embed_dim 
)
static

Definition at line 719 of file ckernel_orchestration.c.

723 {
724  size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
725  for (size_t i = 0; i < total; ++i) {
726  dst[i] += src[i];
727  }
728 }

Referenced by ck_attention_project_head_major(), ck_attention_project_head_major_ref(), ck_layer_backward_rmsnorm_swiglu(), and ck_qkv_project_head_major_backward().

◆ ck_attention_flash_decode_wrapper()

void ck_attention_flash_decode_wrapper ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim 
)

Wrapper to call TRUE flash attention from orchestration layer.

Parameters
q_tokenQuery token [H, D_h]
k_cacheCached keys [T_k, H, D_h]
v_cacheCached values [T_k, H, D_h]
out_tokenOutput [H, D_h]
num_headsNumber of heads
num_kv_headsNumber of KV heads (for GQA)
kv_tokensNumber of tokens in KV cache
cache_capacityCache capacity
head_dimHead dimension
aligned_head_dimAligned head dimension

Definition at line 72 of file ckernel_orchestration.c.

83 {
84  if (!q_token || !k_cache || !v_cache || !out_token) {
85  return;
86  }
87  if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
88  return;
89  }
90  if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
91  return;
92  }
93 
94  static int use_strict = -1;
95  if (use_strict < 0) {
96  const char *env = getenv("CK_FLASH_ATTN_STRICT");
97  use_strict = (env && env[0] && env[0] != '0') ? 1 : 0;
98  }
99 
 100  if (use_strict) {
 101  attention_forward_decode_head_major_gqa_regular(q_token,
 102  k_cache,
 103  v_cache,
 104  out_token,
 105  num_heads,
 106  num_kv_heads,
 107  kv_tokens,
 108  cache_capacity,
 109  head_dim,
 110  aligned_head_dim);
 111  return;
 112  }
113 
114  // Scale factor: 1/sqrt(head_dim)
115  const float scale = 1.0f / sqrtf((float)head_dim);
116  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
117 
118 #pragma omp parallel for schedule(static) if(num_heads > 1)
119  for (int h = 0; h < num_heads; ++h) {
120  const int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
121  const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
122  const float *k_head = k_cache + (size_t)kv_head * head_stride;
123  const float *v_head = v_cache + (size_t)kv_head * head_stride;
124  float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;
125 
126  // Use aligned_head_dim as D_h so per-token stride matches the cache layout.
127  attention_flash_decode(out_head,
128  q_head,
129  k_head,
130  v_head,
131  1,
132  kv_tokens,
133  1,
134  aligned_head_dim,
135  scale);
136  }
137 }
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!

References attention_flash_decode(), and attention_forward_decode_head_major_gqa_regular().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), and ck_layer_forward_rmsnorm_swiglu_decode_quant().

◆ ck_attention_project_head_major()

void ck_attention_project_head_major ( const float *  attn_out,
const float *  wo,
const float *  bo,
float *  out,
float *  scratch,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)

Definition at line 730 of file ckernel_orchestration.c.

739 {
740  if (!attn_out || !wo || !out) {
741  return;
742  }
743  if (num_heads > 1 && !scratch) {
744  return;
745  }
746 
747  size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
748  size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
749 
750  for (int h = 0; h < num_heads; ++h) {
751  const float *head_in = attn_out + (size_t)h * head_in_stride;
752  const float *wo_h = wo + (size_t)h * head_weight_stride;
753 
754  if (h == 0) {
755  gemm_blocked_serial(head_in, wo_h, bo, out,
756  tokens, aligned_embed_dim, aligned_head_dim);
757  } else {
758  gemm_blocked_serial(head_in, wo_h, NULL, scratch,
759  tokens, aligned_embed_dim, aligned_head_dim);
760  ck_add_inplace(out, scratch, tokens, aligned_embed_dim);
761  }
762  }
763 }
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:661
static void ck_add_inplace(float *dst, const float *src, int tokens, int aligned_embed_dim)

References ck_add_inplace(), and gemm_blocked_serial().

Referenced by ck_attention_project_head_major_quant(), and ck_layer_forward_rmsnorm_swiglu().

◆ ck_attention_project_head_major_backward()

void ck_attention_project_head_major_backward ( const float *  d_out,
const float *  attn_out,
const float *  wo,
float *  d_attn_out,
float *  d_wo,
float *  d_bo,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)

Definition at line 800 of file ckernel_orchestration.c.

810 {
811  if (!d_out || !attn_out || !wo || !d_attn_out || !d_wo || !d_bo) {
812  return;
813  }
814 
815  // Bias gradient: sum over tokens once (bias is applied once in forward).
816  for (int d = 0; d < aligned_embed_dim; ++d) {
817  d_bo[d] = 0.0f;
818  }
819  for (int t = 0; t < tokens; ++t) {
820  const float *row = d_out + (size_t)t * (size_t)aligned_embed_dim;
821  for (int d = 0; d < aligned_embed_dim; ++d) {
822  d_bo[d] += row[d];
823  }
824  }
825 
826  size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
827  size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
828 
829  float *tmp_b = (float *)calloc((size_t)aligned_embed_dim, sizeof(float));
830  if (!tmp_b) {
831  return;
832  }
833 
834  for (int h = 0; h < num_heads; ++h) {
835  const float *head_in = attn_out + (size_t)h * head_in_stride;
836  const float *wo_h = wo + (size_t)h * head_weight_stride;
837  float *d_head_in = d_attn_out + (size_t)h * head_in_stride;
838  float *d_wo_h = d_wo + (size_t)h * head_weight_stride;
839 
840  memset(tmp_b, 0, (size_t)aligned_embed_dim * sizeof(float));
841  fc2_backward_kernel(d_out,
842  head_in,
843  wo_h,
844  d_head_in,
845  d_wo_h,
846  tmp_b,
847  tokens,
848  aligned_head_dim,
849  aligned_embed_dim,
850  1);
851  }
852 
853  free(tmp_b);
854 }
void fc2_backward_kernel(const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:118

References fc2_backward_kernel().

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ ck_attention_project_head_major_q4_k()

static void ck_attention_project_head_major_q4_k ( const float *  attn_out,
const void *  wo,
const float *  bo,
float *  out,
float *  scratch,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 504 of file ckernel_orchestration.c.

513 {
514  if (!attn_out || !wo || !out || !scratch) {
515  return;
516  }
517 
518  /* Flatten head-major [H, T, ad] into token-major [T, H*ad] */
519  const int K = num_heads * aligned_head_dim;
520  if (K != aligned_embed_dim) {
521  return;
522  }
523 
524  const size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
525 
526  for (int t = 0; t < tokens; ++t) {
527  float *dst = scratch + (size_t)t * (size_t)aligned_embed_dim;
528  for (int h = 0; h < num_heads; ++h) {
529  const float *src = attn_out + (size_t)h * head_in_stride + (size_t)t * (size_t)aligned_head_dim;
530  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
531  src,
532  (size_t)aligned_head_dim * sizeof(float));
533  }
534  }
535 
536  gemm_nt_q4_k(scratch, wo, bo, out,
537  tokens, aligned_embed_dim, aligned_embed_dim);
538 }
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)

References gemm_nt_q4_k().

Referenced by ck_layer_forward_rmsnorm_swiglu_q4_k().

◆ ck_attention_project_head_major_q4_k_q8_k()

static void ck_attention_project_head_major_q4_k_q8_k ( const float *  attn_out,
const void *  wo,
const float *  bo,
float *  out,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 1764 of file ckernel_orchestration.c.

1772 {
1773  if (!attn_out || !wo || !out) {
1774  return;
1775  }
1776  if (tokens <= 0 || aligned_embed_dim <= 0) {
1777  return;
1778  }
1779  if ((aligned_embed_dim % QK_K) != 0) {
1780  return;
1781  }
1782 
1783  const int K = num_heads * aligned_head_dim;
1784  if (K != aligned_embed_dim) {
1785  return;
1786  }
1787 
1788  const int q8_blocks = aligned_embed_dim / QK_K;
1789  block_q8_K q8_buf[q8_blocks];
1790  float attn_token[aligned_embed_dim];
1791  const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
1792 
1793  for (int t = 0; t < tokens; ++t) {
1794  for (int h = 0; h < num_heads; ++h) {
1795  const float *src = attn_out + (size_t)h * head_stride + (size_t)t * (size_t)aligned_head_dim;
1796  memcpy(attn_token + (size_t)h * (size_t)aligned_head_dim,
1797  src,
1798  (size_t)aligned_head_dim * sizeof(float));
1799  }
1800 
1801  quantize_row_q8_k(attn_token, q8_buf, aligned_embed_dim);
1802  gemm_nt_q4_k_q8_k(q8_buf, wo, bo,
1803  out + (size_t)t * (size_t)aligned_embed_dim,
1804  /*M=*/1, /*N=*/aligned_embed_dim, /*K=*/aligned_embed_dim);
1805  }
1806 }
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void quantize_row_q8_k(const float *x, void *y, int k)
#define QK_K

References gemm_nt_q4_k_q8_k(), QK_K, and quantize_row_q8_k().

Referenced by ck_layer_forward_rmsnorm_swiglu_q4_k().

◆ ck_attention_project_head_major_quant()

static void ck_attention_project_head_major_quant ( const float *  attn_out,
const void *  wo,
const float *  bo,
float *  out,
float *  scratch,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim,
CKDataType  wo_dtype 
)
static

Definition at line 540 of file ckernel_orchestration.c.

550 {
551  if (!attn_out || !wo || !out || !scratch) {
552  return;
553  }
554 
 555  if (wo_dtype == CK_DT_FP32) {
 556  ck_attention_project_head_major(attn_out,
 557  (const float *)wo,
 558  bo,
 559  out,
 560  scratch,
 561  tokens,
 562  aligned_embed_dim,
 563  num_heads,
 564  aligned_head_dim);
 565  return;
 566  }
567 
568  const int K = num_heads * aligned_head_dim;
569  if (K != aligned_embed_dim) {
570  return;
571  }
572 
573  const size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
574 
575  for (int t = 0; t < tokens; ++t) {
576  float *dst = scratch + (size_t)t * (size_t)aligned_embed_dim;
577  for (int h = 0; h < num_heads; ++h) {
578  const float *src = attn_out + (size_t)h * head_in_stride + (size_t)t * (size_t)aligned_head_dim;
579  memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
580  src,
581  (size_t)aligned_head_dim * sizeof(float));
582  }
583  }
584 
585  ck_gemm_nt_quant(scratch, wo, bo, out,
586  tokens, aligned_embed_dim, aligned_embed_dim, wo_dtype);
587 }
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
void ck_attention_project_head_major(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)

References ck_attention_project_head_major(), CK_DT_FP32, and ck_gemm_nt_quant().

Referenced by ck_layer_forward_rmsnorm_swiglu_quant().

◆ ck_attention_project_head_major_ref()

static void ck_attention_project_head_major_ref ( const float *  attn_out,
const float *  wo,
const float *  bo,
float *  out,
float *  scratch,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)
static

Definition at line 765 of file ckernel_orchestration.c.

774 {
775  if (!attn_out || !wo || !out) {
776  return;
777  }
778  if (num_heads > 1 && !scratch) {
779  return;
780  }
781 
782  size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
783  size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
784 
785  for (int h = 0; h < num_heads; ++h) {
786  const float *head_in = attn_out + (size_t)h * head_in_stride;
787  const float *wo_h = wo + (size_t)h * head_weight_stride;
788 
789  if (h == 0) {
790  gemm_naive_parallel(head_in, wo_h, bo, out,
791  tokens, aligned_embed_dim, aligned_head_dim);
792  } else {
793  gemm_naive_parallel(head_in, wo_h, NULL, scratch,
794  tokens, aligned_embed_dim, aligned_head_dim);
795  ck_add_inplace(out, scratch, tokens, aligned_embed_dim);
796  }
797  }
798 }
void gemm_naive_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:125

References ck_add_inplace(), and gemm_naive_parallel().

Referenced by ck_layer_forward_rmsnorm_swiglu_ref().

◆ ck_debug_check_buffer()

static void ck_debug_check_buffer ( const char *  stage,
const float *  buf,
int  size 
)
static

Definition at line 232 of file ckernel_orchestration.c.

233 {
234  if (!ck_layer_debug_enabled() || !buf) {
235  return;
236  }
237  int nan_count = 0, inf_count = 0;
238  float min_val = 1e38f, max_val = -1e38f;
239  for (int i = 0; i < size; ++i) {
240  float v = buf[i];
241  if (isnan(v)) {
242  nan_count++;
243  } else if (isinf(v)) {
244  inf_count++;
245  } else {
246  if (v < min_val) min_val = v;
247  if (v > max_val) max_val = v;
248  }
249  }
250  if (nan_count > 0 || inf_count > 0) {
251  fprintf(stderr, "[LAYER_DEBUG] %-30s size=%5d nan=%d inf=%d\n",
252  stage, size, nan_count, inf_count);
253  } else {
254  fprintf(stderr, "[LAYER_DEBUG] %-30s size=%5d range=[%.3e, %.3e]\n",
255  stage, size, min_val, max_val);
256  }
257 }
static int ck_layer_debug_enabled(void)

References ck_layer_debug_enabled().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_q4_k().

◆ ck_debug_check_q4k_weights()

static void ck_debug_check_q4k_weights ( const char *  stage,
const void *  q4_buf,
int  num_blocks 
)
static

Definition at line 287 of file ckernel_orchestration.c.

288 {
289  if (!ck_layer_debug_enabled() || !q4_buf) {
290  return;
291  }
292  const block_q4_K *blocks = (const block_q4_K *)q4_buf;
293  int nan_d = 0, nan_dmin = 0;
294  float min_d = 1e38f, max_d = -1e38f;
295  for (int i = 0; i < num_blocks; ++i) {
296  float d = CK_FP16_TO_FP32(blocks[i].d);
297  float dm = CK_FP16_TO_FP32(blocks[i].dmin);
298  if (isnan(d)) nan_d++;
299  if (isnan(dm)) nan_dmin++;
300  if (!isnan(d) && !isinf(d)) {
301  if (d < min_d) min_d = d;
302  if (d > max_d) max_d = d;
303  }
304  }
305  if (nan_d > 0 || nan_dmin > 0) {
306  fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d nan_d=%d nan_dmin=%d\n",
307  stage, num_blocks, nan_d, nan_dmin);
308  } else {
309  fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d d_range=[%.3e, %.3e]\n",
310  stage, num_blocks, min_d, max_d);
311  }
312 }
#define CK_FP16_TO_FP32(x)

References CK_FP16_TO_FP32, and ck_layer_debug_enabled().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_q4_k().

◆ ck_debug_check_q8k()

static void ck_debug_check_q8k ( const char *  stage,
const void *  q8_buf,
int  num_blocks 
)
static

Definition at line 259 of file ckernel_orchestration.c.

260 {
261  if (!ck_layer_debug_enabled() || !q8_buf) {
262  return;
263  }
264  const block_q8_K *blocks = (const block_q8_K *)q8_buf;
265  int nan_scale = 0, inf_scale = 0;
266  float min_d = 1e38f, max_d = -1e38f;
267  for (int i = 0; i < num_blocks; ++i) {
268  float d = blocks[i].d;
269  if (isnan(d)) {
270  nan_scale++;
271  } else if (isinf(d)) {
272  inf_scale++;
273  } else {
274  if (d < min_d) min_d = d;
275  if (d > max_d) max_d = d;
276  }
277  }
278  if (nan_scale > 0 || inf_scale > 0) {
279  fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d nan_scale=%d inf_scale=%d\n",
280  stage, num_blocks, nan_scale, inf_scale);
281  } else {
282  fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d scale_range=[%.3e, %.3e]\n",
283  stage, num_blocks, min_d, max_d);
284  }
285 }

References ck_layer_debug_enabled(), and block_q8_K::d.

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_q4_k().

◆ ck_gemm_nt_quant()

void ck_gemm_nt_quant ( const float *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K,
CKDataType  dtype 
)

Definition at line 335 of file ckernel_orchestration.c.

341 {
342  switch (dtype) {
343  case CK_DT_FP32:
344  gemm_blocked_serial(A, (const float *)B, bias, C, M, N, K);
345  break;
346  case CK_DT_Q4_K:
347  gemm_nt_q4_k(A, B, bias, C, M, N, K);
348  break;
349  case CK_DT_Q6_K:
350  gemm_nt_q6_k(A, B, bias, C, M, N, K);
351  break;
352  case CK_DT_Q4_0:
353  gemm_nt_q4_0(A, B, bias, C, M, N, K);
354  break;
355  case CK_DT_Q4_1:
356  gemm_nt_q4_1(A, B, bias, C, M, N, K);
357  break;
358  case CK_DT_Q5_0:
359  gemm_nt_q5_0(A, B, bias, C, M, N, K);
360  break;
361  case CK_DT_Q5_1:
362  gemm_nt_q5_1(A, B, bias, C, M, N, K);
363  break;
364  case CK_DT_Q8_0:
365  gemm_nt_q8_0(A, B, bias, C, M, N, K);
366  break;
367  default:
368  break;
369  }
370 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
@ CK_DT_Q4_0
Definition: ckernel_dtype.h:38
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
@ CK_DT_Q4_1
Definition: ckernel_dtype.h:39
@ CK_DT_Q5_1
Definition: ckernel_dtype.h:45
void gemm_nt_q4_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void gemm_nt_q4_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
void gemm_nt_q5_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q6_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q8_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
#define C(color)
Definition: show_config.c:39

References C, CK_DT_FP32, CK_DT_Q4_0, CK_DT_Q4_1, CK_DT_Q4_K, CK_DT_Q5_0, CK_DT_Q5_1, CK_DT_Q6_K, CK_DT_Q8_0, gemm_blocked_serial(), gemm_nt_q4_0(), gemm_nt_q4_1(), gemm_nt_q4_k(), gemm_nt_q5_0(), gemm_nt_q5_1(), gemm_nt_q6_k(), and gemm_nt_q8_0().

Referenced by ck_attention_project_head_major_quant(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_mlp_swiglu_forward_quant(), ck_qkv_project_head_major_quant(), ck_qkv_project_head_major_token_quant(), and mega_fused_attention_prefill().

◆ ck_layer_backward_rmsnorm_swiglu()

void ck_layer_backward_rmsnorm_swiglu ( const CKLayerBackwardParams p)

Definition at line 2677 of file ckernel_orchestration.c.

2678 {
2679  if (!p) {
2680  return;
2681  }
2682 
2683  int T = p->tokens;
2684  int aligned_embed = p->aligned_embed_dim;
2685  int aligned_head = p->aligned_head_dim;
2686  int aligned_intermediate = p->aligned_intermediate_dim;
2687  int up_dim = 2 * aligned_intermediate;
2688  int num_threads = 1;
2689 
2690  // 1) Residual add (output = residual1 + mlp_out)
2691  ck_residual_add_backward(p->d_output, p->d_residual1, p->d_mlp_out, T, aligned_embed);
2692 
 2693  // 2) MLP down proj backward
 2694  fc2_backward_kernel(p->d_mlp_out,
 2695  p->swiglu_out,
 2696  p->w2,
 2697  p->d_swiglu_out,
 2698  p->d_w2,
 2699  p->d_b2,
 2700  T,
 2701  aligned_intermediate,
 2702  aligned_embed,
 2703  num_threads);
2704 
2705  // 3) SwiGLU backward
2706  swiglu_backward(p->fc1_out, p->d_swiglu_out, p->d_fc1_out, T, aligned_intermediate);
2707 
 2708  // 4) MLP up proj backward
 2709  fc1_backward_kernel(p->d_fc1_out,
 2710  p->ln2_out,
 2711  p->w1,
 2712  p->d_ln2_out,
 2713  p->d_w1,
 2714  p->d_b1,
 2715  T,
 2716  aligned_embed,
 2717  up_dim,
 2718  num_threads);
2719 
 2720  // 5) RMSNorm (ln2) backward; reuse d_output as scratch for d_residual1_from_ln2
 2721  rmsnorm_backward(p->d_ln2_out,
 2722  p->residual1,
 2723  p->ln2_gamma,
 2724  p->ln2_rstd,
 2725  p->d_output,
 2726  p->d_ln2_gamma,
 2727  T,
 2728  p->embed_dim,
 2729  aligned_embed);
2730  ck_add_inplace(p->d_residual1, p->d_output, T, aligned_embed);
2731 
2732  // 6) Residual add (residual1 = input + proj_tmp)
2733  ck_residual_add_backward(p->d_residual1, p->d_input, p->d_proj_tmp, T, aligned_embed);
2734 
 2735  // 7) Attention projection backward
 2736  ck_attention_project_head_major_backward(p->d_proj_tmp,
 2737  p->attn_out,
 2738  p->wo,
 2739  p->d_attn_out,
 2740  p->d_wo,
 2741  p->d_bo,
 2742  T,
 2743  aligned_embed,
 2744  p->num_heads,
 2745  aligned_head);
2746 
 2747  // 8) Attention backward
 2748  attention_backward_causal_head_major_gqa(p->d_attn_out,
 2749  p->q,
 2750  p->k,
 2751  p->v,
 2752  p->scores,
 2753  p->d_q,
 2754  p->d_k,
 2755  p->d_v,
 2756  p->d_scores,
 2757  p->num_heads,
 2758  p->num_kv_heads,
 2759  T,
 2760  p->head_dim,
 2761  aligned_head,
 2762  p->aligned_context_window);
 2763 
2764  // 9) RoPE backward (if enabled)
2765  if (p->rope_cos && p->rope_sin) {
2766  rope_backward_qk(p->d_q,
2767  p->d_k,
2768  p->d_q,
2769  p->d_k,
2770  p->rope_cos,
2771  p->rope_sin,
2772  p->num_heads,
2773  p->num_kv_heads,
2774  T,
2775  p->head_dim,
2776  aligned_head,
2777  p->rope_pos_offset);
2778  }
2779 
 2780  // 10) QKV projection backward (scratch uses d_proj_tmp)
 2781  ck_qkv_project_head_major_backward(p->d_q,
 2782  p->d_k,
 2783  p->d_v,
 2784  p->ln1_out,
 2785  p->wq,
 2786  p->bq,
 2787  p->wk,
 2788  p->bk,
 2789  p->wv,
 2790  p->bv,
 2791  p->d_ln1_out,
 2792  p->d_wq,
 2793  p->d_bq,
 2794  p->d_wk,
 2795  p->d_bk,
 2796  p->d_wv,
 2797  p->d_bv,
 2798  p->d_proj_tmp,
 2799  T,
 2800  aligned_embed,
 2801  p->num_heads,
 2802  p->num_kv_heads,
 2803  aligned_head,
 2804  num_threads);
2805 
 2806  // 11) RMSNorm (ln1) backward; reuse d_ln1_out as scratch for d_input_from_ln1
 2807  rmsnorm_backward(p->d_ln1_out,
 2808  p->input,
 2809  p->ln1_gamma,
 2810  p->ln1_rstd,
 2811  p->d_ln1_out,
 2812  p->d_ln1_gamma,
 2813  T,
 2814  p->embed_dim,
 2815  aligned_embed);
2816  ck_add_inplace(p->d_input, p->d_ln1_out, T, aligned_embed);
2817 }
void swiglu_backward(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void rope_backward_qk(const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:497
void fc1_backward_kernel(const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:167
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void ck_residual_add_backward(const float *d_out, float *d_a, float *d_b, int tokens, int aligned_embed_dim)
void ck_qkv_project_head_major_backward(const float *d_q, const float *d_k, const float *d_v, const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *d_input, float *d_wq, float *d_bq, float *d_wk, float *d_bk, float *d_wv, float *d_bv, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)
void ck_attention_project_head_major_backward(const float *d_out, const float *attn_out, const float *wo, float *d_attn_out, float *d_wo, float *d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

References CKLayerBackwardParams::aligned_context_window, CKLayerBackwardParams::aligned_embed_dim, CKLayerBackwardParams::aligned_head_dim, CKLayerBackwardParams::aligned_intermediate_dim, attention_backward_causal_head_major_gqa(), CKLayerBackwardParams::attn_out, CKLayerBackwardParams::bk, CKLayerBackwardParams::bq, CKLayerBackwardParams::bv, ck_add_inplace(), ck_attention_project_head_major_backward(), ck_qkv_project_head_major_backward(), ck_residual_add_backward(), CKLayerBackwardParams::d_attn_out, CKLayerBackwardParams::d_b1, CKLayerBackwardParams::d_b2, CKLayerBackwardParams::d_bk, CKLayerBackwardParams::d_bo, CKLayerBackwardParams::d_bq, CKLayerBackwardParams::d_bv, CKLayerBackwardParams::d_fc1_out, CKLayerBackwardParams::d_input, CKLayerBackwardParams::d_k, CKLayerBackwardParams::d_ln1_gamma, CKLayerBackwardParams::d_ln1_out, CKLayerBackwardParams::d_ln2_gamma, CKLayerBackwardParams::d_ln2_out, CKLayerBackwardParams::d_mlp_out, CKLayerBackwardParams::d_output, CKLayerBackwardParams::d_proj_tmp, CKLayerBackwardParams::d_q, CKLayerBackwardParams::d_residual1, CKLayerBackwardParams::d_scores, CKLayerBackwardParams::d_swiglu_out, CKLayerBackwardParams::d_v, CKLayerBackwardParams::d_w1, CKLayerBackwardParams::d_w2, CKLayerBackwardParams::d_wk, CKLayerBackwardParams::d_wo, CKLayerBackwardParams::d_wq, CKLayerBackwardParams::d_wv, CKLayerBackwardParams::embed_dim, fc1_backward_kernel(), CKLayerBackwardParams::fc1_out, fc2_backward_kernel(), CKLayerBackwardParams::head_dim, CKLayerBackwardParams::input, CKLayerBackwardParams::k, CKLayerBackwardParams::ln1_gamma, CKLayerBackwardParams::ln1_out, CKLayerBackwardParams::ln1_rstd, CKLayerBackwardParams::ln2_gamma, CKLayerBackwardParams::ln2_out, CKLayerBackwardParams::ln2_rstd, CKLayerBackwardParams::num_heads, CKLayerBackwardParams::num_kv_heads, CKLayerBackwardParams::q, CKLayerBackwardParams::residual1, rmsnorm_backward(), rope_backward_qk(), CKLayerBackwardParams::rope_cos, CKLayerBackwardParams::rope_pos_offset, 
CKLayerBackwardParams::rope_sin, CKLayerBackwardParams::scores, swiglu_backward(), CKLayerBackwardParams::swiglu_out, CKLayerBackwardParams::tokens, CKLayerBackwardParams::v, CKLayerBackwardParams::w1, CKLayerBackwardParams::w2, CKLayerBackwardParams::wk, CKLayerBackwardParams::wo, CKLayerBackwardParams::wq, and CKLayerBackwardParams::wv.

◆ ck_layer_debug_enabled()

static int ck_layer_debug_enabled ( void  )
static

Definition at line 217 of file ckernel_orchestration.c.

218 {
219  static int cached = -2;
220  if (cached != -2) {
221  return cached;
222  }
223  const char *env = getenv("CK_LAYER_DEBUG");
224  if (env && (env[0] == '1' || env[0] == 'y' || env[0] == 'Y')) {
225  cached = 1;
226  } else {
227  cached = 0;
228  }
229  return cached;
230 }

Referenced by ck_debug_check_buffer(), ck_debug_check_q4k_weights(), and ck_debug_check_q8k().

◆ ck_layer_forward_rmsnorm_swiglu()

void ck_layer_forward_rmsnorm_swiglu ( const CKLayerForwardParams p)

Definition at line 996 of file ckernel_orchestration.c.

997 {
998  if (!p) {
999  return;
1000  }
1001 
1002  rmsnorm_forward(p->input,
1003  p->ln1_gamma,
1004  p->ln1_out,
1005  p->ln1_rstd,
1006  p->tokens,
1007  p->embed_dim,
1008  p->aligned_embed_dim,
1009  p->eps);
1010 
1011  ck_qkv_project_head_major(p->ln1_out,
1012  p->wq, p->bq,
1013  p->wk, p->bk,
1014  p->wv, p->bv,
1015  p->q, p->k, p->v,
1016  p->tokens,
1017  p->tokens,
1018  p->aligned_embed_dim,
1019  p->num_heads,
1020  p->num_kv_heads,
1021  p->aligned_head_dim);
1022 
1023  if (p->rope_cos && p->rope_sin) {
1024  rope_forward_qk(p->q,
1025  p->k,
1026  p->rope_cos,
1027  p->rope_sin,
1028  p->num_heads,
1029  p->num_kv_heads,
1030  p->tokens,
1031  p->head_dim,
1032  p->aligned_head_dim,
1033  p->rope_pos_offset);
1034  }
1035 
1036  if (p->scores) {
1037  attention_forward_causal_head_major_gqa(p->q,
1038  p->k,
1039  p->v,
1040  p->scores,
1041  p->attn_out,
1042  p->num_heads,
1043  p->num_kv_heads,
1044  p->tokens,
1045  p->head_dim,
1046  p->aligned_head_dim,
1047  p->aligned_context_window);
1048  } else {
1049  attention_forward_causal_head_major_gqa_flash(p->q,
1050  p->k,
1051  p->v,
1052  p->attn_out,
1053  p->num_heads,
1054  p->num_kv_heads,
1055  p->tokens,
1056  p->head_dim,
1057  p->aligned_head_dim);
1058  }
1059 
1060  ck_attention_project_head_major(p->attn_out,
1061  p->wo,
1062  p->bo,
1063  p->proj_tmp,
1064  p->proj_scratch,
1065  p->tokens,
1066  p->aligned_embed_dim,
1067  p->num_heads,
1068  p->aligned_head_dim);
1069 
1070  ck_residual_add_token_major(p->input,
1071  p->proj_tmp,
1072  p->residual1,
1073  p->tokens,
1074  p->aligned_embed_dim);
1075 
1076  rmsnorm_forward(p->residual1,
1077  p->ln2_gamma,
1078  p->ln2_out,
1079  p->ln2_rstd,
1080  p->tokens,
1081  p->embed_dim,
1082  p->aligned_embed_dim,
1083  p->eps);
1084 
1085  ck_mlp_swiglu_forward(p->ln2_out,
1086  p->w1,
1087  p->b1,
1088  p->w2,
1089  p->b2,
1090  p->fc1_out,
1091  p->swiglu_out,
1092  p->mlp_out,
1093  p->tokens,
1094  p->aligned_embed_dim,
1095  p->aligned_intermediate_dim);
1096 
1097  ck_residual_add_token_major(p->residual1,
1098  p->mlp_out,
1099  p->output,
1100  p->tokens,
1101  p->aligned_embed_dim);
1102 }
void attention_forward_causal_head_major_gqa(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:448
void ck_mlp_swiglu_forward(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_qkv_project_head_major(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)

References CKLayerForwardParams::aligned_context_window, CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParams::attn_out, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_project_head_major(), ck_mlp_swiglu_forward(), ck_qkv_project_head_major(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_scratch, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::q, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::scores, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::tokens, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.

◆ ck_layer_forward_rmsnorm_swiglu_decode()

void ck_layer_forward_rmsnorm_swiglu_decode ( const CKLayerForwardParams p,
int  token_index,
int  cache_capacity 
)

Definition at line 1289 of file ckernel_orchestration.c.

1292 {
1293  if (!p) {
1294  return;
1295  }
1296  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
1297  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
1298  !p->k || !p->v ||
1299  !p->proj_tmp || !p->residual1 || !p->fc1_out || !p->swiglu_out || !p->mlp_out || !p->output) {
1300  return;
1301  }
1302  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
1303  return;
1304  }
1305  if (p->num_heads <= 0 || p->num_kv_heads <= 0 || p->aligned_head_dim <= 0) {
1306  return;
1307  }
1308 
1309  const int D = p->embed_dim;
1310  const int aligned_D = p->aligned_embed_dim;
1311  const int H = p->num_heads;
1312  const int H_kv = p->num_kv_heads;
1313  const int hd = p->head_dim;
1314  const int ad = p->aligned_head_dim;
1315  const int aligned_intermediate = p->aligned_intermediate_dim;
1316 
1317  /* Decode buffers are single-token; token_index only applies to KV cache. */
1318  const size_t token_slot = 0;
1319  const float *input_row = p->input + token_slot * (size_t)aligned_D;
1320  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
1321  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
1322  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
1323  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
1324  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
1325  float *out_row = p->output + token_slot * (size_t)aligned_D;
1326 
1327  float ln1_rstd_tmp = 0.0f;
1328  float ln2_rstd_tmp = 0.0f;
1329  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
1330  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
1331 
1332  // Scratch for a single token in head-major layout: [head, aligned_head_dim].
1333  size_t q_elems = (size_t)H * (size_t)ad;
1334  size_t kv_elems = (size_t)H_kv * (size_t)ad;
1335  float q_token[q_elems];
1336  float k_token[kv_elems];
1337  float v_token[kv_elems];
1338  float attn_token[q_elems];
1339 
1340  // LN1 / RMSNorm.
1341  rmsnorm_forward(input_row,
1342  p->ln1_gamma,
1343  ln1_row,
1344  ln1_rstd,
1345  /*tokens=*/1,
1346  D,
1347  aligned_D,
1348  p->eps);
1349 
1350  // Project Q/K/V for the new token.
1351  ck_qkv_project_head_major_token(ln1_row,
1352  p->wq, p->bq,
1353  p->wk, p->bk,
1354  p->wv, p->bv,
1355  q_token, k_token, v_token,
1356  aligned_D,
1357  H,
1358  H_kv,
1359  ad);
1360 
1361  // RoPE for the new token at absolute position `p->rope_pos_offset`.
1362  if (p->rope_cos && p->rope_sin) {
1363  rope_forward_qk(q_token,
1364  k_token,
1365  p->rope_cos,
1366  p->rope_sin,
1367  H,
1368  H_kv,
1369  /*num_tokens=*/1,
1370  hd,
1371  ad,
1372  p->rope_pos_offset);
1373  }
1374 
1375  // Update KV cache (stores k/v for this token and clears padded lanes).
1376  kv_cache_write_head_major(k_token,
1377  v_token,
1378  p->k,
1379  p->v,
1380  H_kv,
1381  token_index,
1382  cache_capacity,
1383  hd,
1384  ad);
1385 
1386  // Decode attention for this token using the KV cache.
1387  ck_attention_flash_decode_wrapper(q_token,
1388  p->k,
1389  p->v,
1390  attn_token,
1391  H,
1392  H_kv,
1393  /*kv_tokens=*/token_index + 1,
1394  cache_capacity,
1395  hd,
1396  ad);
1397 
1398  // Output projection (Wo) into token-major buffer (decode-specialized).
1399  ck_attention_project_head_major_decode_token(attn_token,
1400  p->wo,
1401  p->bo,
1402  proj_row,
1403  D,
1404  aligned_D,
1405  H,
1406  ad);
1407 
1408  // Residual + LN2 / RMSNorm.
1409  ck_residual_add_token_major(input_row,
1410  proj_row,
1411  residual_row,
1412  /*tokens=*/1,
1413  aligned_D);
1414 
1415  rmsnorm_forward(residual_row,
1416  p->ln2_gamma,
1417  ln2_row,
1418  ln2_rstd,
1419  /*tokens=*/1,
1420  D,
1421  aligned_D,
1422  p->eps);
1423 
1424  // MLP block for this token.
1425  int up_dim = 2 * aligned_intermediate;
1426  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
1427  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
1428 
1429  ck_mlp_swiglu_forward(ln2_row,
1430  p->w1,
1431  p->b1,
1432  p->w2,
1433  p->b2,
1434  fc1_row,
1435  swiglu_row,
1436  mlp_row,
1437  /*tokens=*/1,
1438  aligned_D,
1439  aligned_intermediate);
1440 
1441  // Final residual.
1442  ck_residual_add_token_major(residual_row,
1443  mlp_row,
1444  out_row,
1445  /*tokens=*/1,
1446  aligned_D);
1447 }
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
void ck_attention_flash_decode_wrapper(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
Wrapper to call TRUE flash attention from orchestration layer.
void ck_qkv_project_head_major_token(const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_attention_project_head_major_decode_token(const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)

References CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_flash_decode_wrapper(), ck_attention_project_head_major_decode_token(), ck_mlp_swiglu_forward(), ck_qkv_project_head_major_token(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, kv_cache_write_head_major(), CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.

◆ ck_layer_forward_rmsnorm_swiglu_decode_fused()

void ck_layer_forward_rmsnorm_swiglu_decode_fused ( const CKLayerForwardParams p,
int  token_index,
int  cache_capacity 
)

Definition at line 1449 of file ckernel_orchestration.c.

1452 {
1453  if (!p) {
1454  return;
1455  }
1456  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
1457  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
1458  !p->k || !p->v || !p->swiglu_out ||
1459  !p->proj_tmp || !p->residual1 || !p->mlp_out || !p->output) {
1460  return;
1461  }
1462  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
1463  return;
1464  }
1465  if (p->num_heads <= 0 || p->num_kv_heads <= 0 || p->aligned_head_dim <= 0) {
1466  return;
1467  }
1468 
1469  const int D = p->embed_dim;
1470  const int aligned_D = p->aligned_embed_dim;
1471  const int H = p->num_heads;
1472  const int H_kv = p->num_kv_heads;
1473  const int hd = p->head_dim;
1474  const int ad = p->aligned_head_dim;
1475  const int aligned_intermediate = p->aligned_intermediate_dim;
1476 
1477  /* Decode buffers are single-token; token_index only applies to KV cache. */
1478  const size_t token_slot = 0;
1479  const float *input_row = p->input + token_slot * (size_t)aligned_D;
1480  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
1481  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
1482  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
1483  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
1484  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
1485  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
1486  float *out_row = p->output + token_slot * (size_t)aligned_D;
1487 
1488  float ln1_rstd_tmp = 0.0f;
1489  float ln2_rstd_tmp = 0.0f;
1490  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
1491  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
1492 
1493  // Scratch for a single token in head-major layout: [head, aligned_head_dim].
1494  size_t q_elems = (size_t)H * (size_t)ad;
1495  size_t kv_elems = (size_t)H_kv * (size_t)ad;
1496  float q_token[q_elems];
1497  float k_token[kv_elems];
1498  float v_token[kv_elems];
1499  float attn_token[q_elems];
1500 
1501  // LN1 / RMSNorm.
1502  rmsnorm_forward(input_row,
1503  p->ln1_gamma,
1504  ln1_row,
1505  ln1_rstd,
1506  /*tokens=*/1,
1507  D,
1508  aligned_D,
1509  p->eps);
1510 
1511  // Project Q/K/V for the new token.
1512  ck_qkv_project_head_major_token(ln1_row,
1513  p->wq, p->bq,
1514  p->wk, p->bk,
1515  p->wv, p->bv,
1516  q_token, k_token, v_token,
1517  aligned_D,
1518  H,
1519  H_kv,
1520  ad);
1521 
1522  // RoPE for the new token at absolute position `p->rope_pos_offset`.
1523  if (p->rope_cos && p->rope_sin) {
1524  rope_forward_qk(q_token,
1525  k_token,
1526  p->rope_cos,
1527  p->rope_sin,
1528  H,
1529  H_kv,
1530  /*num_tokens=*/1,
1531  hd,
1532  ad,
1533  p->rope_pos_offset);
1534  }
1535 
1536  // Update KV cache (stores k/v for this token and clears padded lanes).
1537  kv_cache_write_head_major(k_token,
1538  v_token,
1539  p->k,
1540  p->v,
1541  H_kv,
1542  token_index,
1543  cache_capacity,
1544  hd,
1545  ad);
1546 
1547  // Decode attention for this token using the KV cache.
1549  ck_attention_flash_decode_wrapper(q_token,
1550  p->k,
1550  p->v,
1551  attn_token,
1552  H,
1553  H_kv,
1554  /*kv_tokens=*/token_index + 1,
1555  cache_capacity,
1556  hd,
1557  ad);
1558 
1559  // Output projection (Wo) into token-major buffer (decode-specialized).
1560  ck_attention_project_head_major_decode_token(attn_token,
1561  p->wo,
1562  p->bo,
1563  proj_row,
1564  D,
1565  aligned_D,
1566  H,
1567  ad);
1568 
1569  // Residual + LN2 / RMSNorm.
1570  ck_residual_add_token_major(input_row,
1571  proj_row,
1572  residual_row,
1573  /*tokens=*/1,
1574  aligned_D);
1575 
1576  rmsnorm_forward(residual_row,
1577  p->ln2_gamma,
1578  ln2_row,
1579  ln2_rstd,
1580  /*tokens=*/1,
1581  D,
1582  aligned_D,
1583  p->eps);
1584 
1585  // MLP block for this token (fully fused - all 3 projections in one pass).
1586  // Eliminates DRAM round-trip for swiglu intermediate values.
1587  ck_mlp_swiglu_forward_fully_fused_token(ln2_row,
1588  p->w1,
1589  p->b1,
1590  p->w2,
1591  p->b2,
1592  mlp_row,
1593  aligned_D,
1594  aligned_intermediate);
1595 
1596  // Final residual.
1597  ck_residual_add_token_major(residual_row,
1598  mlp_row,
1599  out_row,
1600  /*tokens=*/1,
1601  aligned_D);
1602 }
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)

References CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_flash_decode_wrapper(), ck_attention_project_head_major_decode_token(), ck_mlp_swiglu_forward_fully_fused_token(), ck_qkv_project_head_major_token(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, kv_cache_write_head_major(), CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.

◆ ck_layer_forward_rmsnorm_swiglu_decode_q4_k()

void ck_layer_forward_rmsnorm_swiglu_decode_q4_k ( const CKLayerForwardParamsQ4K p,
int  token_index,
int  cache_capacity 
)

Definition at line 2117 of file ckernel_orchestration.c.

2120 {
2121  if (!p) {
2122  return;
2123  }
2124  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
2125  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
2126  !p->k || !p->v ||
2127  !p->proj_tmp || !p->residual1 || !p->fc1_out || !p->swiglu_out || !p->mlp_out || !p->output) {
2128  return;
2129  }
2130  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
2131  return;
2132  }
2133 
2134  const int D = p->embed_dim;
2135  const int aligned_D = p->aligned_embed_dim;
2136  const int H = p->num_heads;
2137  const int H_kv = p->num_kv_heads;
2138  const int hd = p->head_dim;
2139  const int ad = p->aligned_head_dim;
2140  const int aligned_intermediate = p->aligned_intermediate_dim;
2141  const int K_concat = H * ad;
2142 
2143  /* Decode buffers are single-token; token_index only applies to KV cache. */
2144  const size_t token_slot = 0;
2145  const float *input_row = p->input + token_slot * (size_t)aligned_D;
2146  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
2147  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
2148  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
2149  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
2150  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
2151  float *out_row = p->output + token_slot * (size_t)aligned_D;
2152 
2153  float ln1_rstd_tmp = 0.0f;
2154  float ln2_rstd_tmp = 0.0f;
2155  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
2156  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
2157 
2158  /* Scratch for a single token in head-major layout: [head, aligned_head_dim]. */
2159  size_t q_elems = (size_t)H * (size_t)ad;
2160  size_t kv_elems = (size_t)H_kv * (size_t)ad;
2161  float q_token[q_elems];
2162  float k_token[kv_elems];
2163  float v_token[kv_elems];
2164  float attn_token[q_elems];
2165 
2166  /* LN1 / RMSNorm. */
2167  ck_debug_check_buffer("input_row", input_row, aligned_D);
2168  rmsnorm_forward(input_row,
2169  p->ln1_gamma,
2170  ln1_row,
2171  ln1_rstd,
2172  /*tokens=*/1,
2173  D,
2174  aligned_D,
2175  p->eps);
2176  ck_debug_check_buffer("ln1_out (after rmsnorm)", ln1_row, aligned_D);
2177 
2178  if (ck_q8k_activations_enabled()) {
2179  if ((aligned_D % QK_K) == 0 && (aligned_intermediate % QK_K) == 0) {
2180  const int q8_blocks_embed = aligned_D / QK_K;
2181  const int q8_blocks_inter = aligned_intermediate / QK_K;
2182  const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
2183  block_q8_K q8_buf[q8_blocks_max];
2184 
2185  /* Project Q/K/V with Q8_K activations. */
2186  quantize_row_q8_k(ln1_row, q8_buf, aligned_D);
2187  ck_debug_check_q8k("q8_buf (after quantize)", q8_buf, q8_blocks_embed);
2188  ck_debug_check_q4k_weights("wq weights", p->wq, (aligned_D / QK_K) * (H * ad));
2189  ck_qkv_project_head_major_token_q4_k_q8_k(q8_buf,
2190  p->wq, p->bq,
2191  p->wk, p->bk,
2192  p->wv, p->bv,
2193  q_token, k_token, v_token,
2194  aligned_D,
2195  H,
2196  H_kv,
2197  ad);
2198  ck_debug_check_buffer("q_token (after QKV proj)", q_token, (int)q_elems);
2199  ck_debug_check_buffer("k_token (after QKV proj)", k_token, (int)kv_elems);
2200  ck_debug_check_buffer("v_token (after QKV proj)", v_token, (int)kv_elems);
2201 
2202  /* RoPE for the new token at absolute position `p->rope_pos_offset`. */
2203  if (p->rope_cos && p->rope_sin) {
2204  rope_forward_qk(q_token,
2205  k_token,
2206  p->rope_cos,
2207  p->rope_sin,
2208  H,
2209  H_kv,
2210  /*num_tokens=*/1,
2211  hd,
2212  ad,
2213  p->rope_pos_offset);
2214  }
2215 
2216  /* Update KV cache. */
2217  kv_cache_write_head_major(k_token,
2218  v_token,
2219  p->k,
2220  p->v,
2221  H_kv,
2222  token_index,
2223  cache_capacity,
2224  hd,
2225  ad);
2226 
2227  /* Decode attention for this token using the KV cache. */
2229  p->k,
2230  p->v,
2231  attn_token,
2232  H,
2233  H_kv,
2234  /*kv_tokens=*/token_index + 1,
2235  cache_capacity,
2236  hd,
2237  ad);
2238  ck_debug_check_buffer("attn_token (after attention)", attn_token, (int)q_elems);
2239 
2240  /* Quantized output projection (Wo) with Q8_K activations. */
2241  quantize_row_q8_k(attn_token, q8_buf, aligned_D);
2242  gemm_nt_q4_k_q8_k(q8_buf,
2243  p->wo,
2244  p->bo,
2245  proj_row,
2246  /*M=*/1,
2247  aligned_D,
2248  /*K=*/K_concat);
2249  ck_debug_check_buffer("proj_row (after Wo proj)", proj_row, aligned_D);
2250 
2251  for (int j = D; j < aligned_D; ++j) {
2252  proj_row[j] = 0.0f;
2253  }
2254 
2255  /* Residual + LN2 / RMSNorm. */
2256  ck_residual_add_token_major(input_row,
2257  proj_row,
2258  residual_row,
2259  /*tokens=*/1,
2260  aligned_D);
2261 
2262  rmsnorm_forward(residual_row,
2263  p->ln2_gamma,
2264  ln2_row,
2265  ln2_rstd,
2266  /*tokens=*/1,
2267  D,
2268  aligned_D,
2269  p->eps);
2270 
2271  /* MLP block for this token (Q8_K activations). */
2272  int up_dim = 2 * aligned_intermediate;
2273  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
2274  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
2275 
2276  ck_mlp_swiglu_forward_q4_k_q8_k(ln2_row,
2277  p->w1,
2278  p->b1,
2279  p->w2,
2280  p->b2,
2281  fc1_row,
2282  swiglu_row,
2283  mlp_row,
2284  aligned_D,
2285  aligned_intermediate);
2286  ck_debug_check_buffer("mlp_row (after MLP)", mlp_row, aligned_D);
2287 
2288  /* Final residual. */
2289  ck_residual_add_token_major(residual_row,
2290  mlp_row,
2291  out_row,
2292  /*tokens=*/1,
2293  aligned_D);
2294  ck_debug_check_buffer("out_row (final output)", out_row, aligned_D);
2295  return;
2296  }
2297  }
2298 
2299  /* Project Q/K/V for the new token (Q4_K weights). */
2300  ck_qkv_project_head_major_token_q4_k(ln1_row,
2301  p->wq, p->bq,
2302  p->wk, p->bk,
2303  p->wv, p->bv,
2304  q_token, k_token, v_token,
2305  aligned_D,
2306  H,
2307  H_kv,
2308  ad);
2309 
2310  /* RoPE for the new token at absolute position `p->rope_pos_offset`. */
2311  if (p->rope_cos && p->rope_sin) {
2312  rope_forward_qk(q_token,
2313  k_token,
2314  p->rope_cos,
2315  p->rope_sin,
2316  H,
2317  H_kv,
2318  /*num_tokens=*/1,
2319  hd,
2320  ad,
2321  p->rope_pos_offset);
2322  }
2323 
2324  /* Update KV cache. */
2325  kv_cache_write_head_major(k_token,
2326  v_token,
2327  p->k,
2328  p->v,
2329  H_kv,
2330  token_index,
2331  cache_capacity,
2332  hd,
2333  ad);
2334 
2335  /* Decode attention for this token using the KV cache. */
2337  p->k,
2338  p->v,
2339  attn_token,
2340  H,
2341  H_kv,
2342  /*kv_tokens=*/token_index + 1,
2343  cache_capacity,
2344  hd,
2345  ad);
2346 
2347  /* Quantized output projection: Wo is stored as a flattened Q4_K matrix. */
2348  gemm_nt_q4_k(attn_token,
2349  p->wo,
2350  p->bo,
2351  proj_row,
2352  /*M=*/1,
2353  aligned_D,
2354  /*K=*/K_concat);
2355 
2356  for (int j = D; j < aligned_D; ++j) {
2357  proj_row[j] = 0.0f;
2358  }
2359 
2360  /* Residual + LN2 / RMSNorm. */
2361  ck_residual_add_token_major(input_row,
2362  proj_row,
2363  residual_row,
2364  /*tokens=*/1,
2365  aligned_D);
2366 
2367  rmsnorm_forward(residual_row,
2368  p->ln2_gamma,
2369  ln2_row,
2370  ln2_rstd,
2371  /*tokens=*/1,
2372  D,
2373  aligned_D,
2374  p->eps);
2375 
2376  /* MLP block for this token. */
2377  int up_dim = 2 * aligned_intermediate;
2378  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
2379  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
2380 
2381  ck_mlp_swiglu_forward_q4_k(ln2_row,
2382  p->w1,
2383  p->b1,
2384  p->w2,
2385  p->b2,
2386  fc1_row,
2387  swiglu_row,
2388  mlp_row,
2389  /*tokens=*/1,
2390  aligned_D,
2391  aligned_intermediate);
2392 
2393  /* Final residual. */
2394  ck_residual_add_token_major(residual_row,
2395  mlp_row,
2396  out_row,
2397  /*tokens=*/1,
2398  aligned_D);
2399 }
static void ck_qkv_project_head_major_token_q4_k_q8_k(const block_q8_K *input_q8, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static int ck_q8k_activations_enabled(void)
static void ck_mlp_swiglu_forward_q4_k_q8_k(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_debug_check_q8k(const char *stage, const void *q8_buf, int num_blocks)
static void ck_mlp_swiglu_forward_q4_k(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_debug_check_q4k_weights(const char *stage, const void *q4_buf, int num_blocks)
static void ck_debug_check_buffer(const char *stage, const float *buf, int size)
static void ck_qkv_project_head_major_token_q4_k(const float *input_row, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

References CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, attention_forward_decode_head_major_gqa_regular(), CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_flash_decode_wrapper(), ck_debug_check_buffer(), ck_debug_check_q4k_weights(), ck_debug_check_q8k(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_q8k_activations_enabled(), ck_qkv_project_head_major_token_q4_k(), ck_qkv_project_head_major_token_q4_k_q8_k(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, gemm_nt_q4_k(), gemm_nt_q4_k_q8_k(), CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, kv_cache_write_head_major(), CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_tmp, QK_K, quantize_row_q8_k(), CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wq, and CKLayerForwardParamsQ4K::wv.

◆ ck_layer_forward_rmsnorm_swiglu_decode_quant()

void ck_layer_forward_rmsnorm_swiglu_decode_quant ( const CKLayerForwardParamsQ4K p,
int  token_index,
int  cache_capacity 
)

Definition at line 2512 of file ckernel_orchestration.c.

2515 {
2516  if (!p) {
2517  return;
2518  }
2519  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
2520  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
2521  !p->k || !p->v ||
2522  !p->proj_tmp || !p->proj_scratch || !p->residual1 || !p->fc1_out || !p->swiglu_out || !p->mlp_out || !p->output) {
2523  return;
2524  }
2525  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
2526  return;
2527  }
2528 
2529  const int D = p->embed_dim;
2530  const int aligned_D = p->aligned_embed_dim;
2531  const int H = p->num_heads;
2532  const int H_kv = p->num_kv_heads;
2533  const int hd = p->head_dim;
2534  const int ad = p->aligned_head_dim;
2535  const int aligned_intermediate = p->aligned_intermediate_dim;
2536  const int K_concat = H * ad;
2537 
2538  /* Decode buffers are single-token; token_index only applies to KV cache. */
2539  const size_t token_slot = 0;
2540  const float *input_row = p->input + token_slot * (size_t)aligned_D;
2541  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
2542  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
2543  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
2544  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
2545  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
2546  float *out_row = p->output + token_slot * (size_t)aligned_D;
2547 
2548  float ln1_rstd_tmp = 0.0f;
2549  float ln2_rstd_tmp = 0.0f;
2550  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
2551  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
2552 
2553  size_t q_elems = (size_t)H * (size_t)ad;
2554  size_t kv_elems = (size_t)H_kv * (size_t)ad;
2555  float q_token[q_elems];
2556  float k_token[kv_elems];
2557  float v_token[kv_elems];
2558  float attn_token[q_elems];
2559 
2560  rmsnorm_forward(input_row,
2561  p->ln1_gamma,
2562  ln1_row,
2563  ln1_rstd,
2564  /*tokens=*/1,
2565  D,
2566  aligned_D,
2567  p->eps);
2568 
2570  p->wq, p->bq, p->wq_dtype,
2571  p->wk, p->bk, p->wk_dtype,
2572  p->wv, p->bv, p->wv_dtype,
2573  q_token, k_token, v_token,
2574  aligned_D,
2575  H,
2576  H_kv,
2577  ad);
2578 
2579  if (p->rope_cos && p->rope_sin) {
2580  rope_forward_qk(q_token,
2581  k_token,
2582  p->rope_cos,
2583  p->rope_sin,
2584  H,
2585  H_kv,
2586  /*num_tokens=*/1,
2587  hd,
2588  ad,
2589  p->rope_pos_offset);
2590  }
2591 
2592  kv_cache_write_head_major(k_token,
2593  v_token,
2594  p->k,
2595  p->v,
2596  H_kv,
2597  token_index,
2598  cache_capacity,
2599  hd,
2600  ad);
2601 
2603  p->k,
2604  p->v,
2605  attn_token,
2606  H,
2607  H_kv,
2608  /*kv_tokens=*/token_index + 1,
2609  cache_capacity,
2610  hd,
2611  ad);
2612 
2613  if (p->wo_dtype == CK_DT_FP32) {
2615  (const float *)p->wo,
2616  p->bo,
2617  proj_row,
2618  D,
2619  aligned_D,
2620  H,
2621  ad);
2622  } else {
2623  /* Quantized attention output projection - handle all quant types */
2624  ck_gemm_nt_quant(attn_token,
2625  p->wo,
2626  p->bo,
2627  proj_row,
2628  /*M=*/1,
2629  aligned_D,
2630  /*K=*/K_concat,
2631  p->wo_dtype);
2632  for (int j = D; j < aligned_D; ++j) {
2633  proj_row[j] = 0.0f;
2634  }
2635  }
2636 
2637  ck_residual_add_token_major(input_row,
2638  proj_row,
2639  residual_row,
2640  /*tokens=*/1,
2641  aligned_D);
2642 
2643  rmsnorm_forward(residual_row,
2644  p->ln2_gamma,
2645  ln2_row,
2646  ln2_rstd,
2647  /*tokens=*/1,
2648  D,
2649  aligned_D,
2650  p->eps);
2651 
2652  int up_dim = 2 * aligned_intermediate;
2653  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
2654  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
2655 
2657  p->w1,
2658  p->b1,
2659  p->w1_dtype,
2660  p->w2,
2661  p->b2,
2662  p->w2_dtype,
2663  fc1_row,
2664  swiglu_row,
2665  mlp_row,
2666  /*tokens=*/1,
2667  aligned_D,
2668  aligned_intermediate);
2669 
2670  ck_residual_add_token_major(residual_row,
2671  mlp_row,
2672  out_row,
2673  /*tokens=*/1,
2674  aligned_D);
2675 }
static void ck_mlp_swiglu_forward_quant(const float *input, const void *w1, const float *b1, CKDataType w1_dtype, const void *w2, const float *b2, CKDataType w2_dtype, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_qkv_project_head_major_token_quant(const float *input_row, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

References CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_flash_decode_wrapper(), ck_attention_project_head_major_decode_token(), CK_DT_FP32, ck_gemm_nt_quant(), ck_mlp_swiglu_forward_quant(), ck_qkv_project_head_major_token_quant(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, kv_cache_write_head_major(), CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_scratch, CKLayerForwardParamsQ4K::proj_tmp, CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w1_dtype, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::w2_dtype, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wk_dtype, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wo_dtype, CKLayerForwardParamsQ4K::wq, CKLayerForwardParamsQ4K::wq_dtype, CKLayerForwardParamsQ4K::wv, and CKLayerForwardParamsQ4K::wv_dtype.

◆ ck_layer_forward_rmsnorm_swiglu_q4_k()

void ck_layer_forward_rmsnorm_swiglu_q4_k ( const CKLayerForwardParamsQ4K p)

Definition at line 1910 of file ckernel_orchestration.c.

1911 {
1912  if (!p) {
1913  return;
1914  }
1915 
1916  const int aligned_D = p->aligned_embed_dim;
1917  const int aligned_intermediate = p->aligned_intermediate_dim;
1918 
1920  p->ln1_gamma,
1921  p->ln1_out,
1922  p->ln1_rstd,
1923  p->tokens,
1924  p->embed_dim,
1925  aligned_D,
1926  p->eps);
1927 
1929  if ((aligned_D % QK_K) == 0 && (aligned_intermediate % QK_K) == 0) {
1931  p->wq, p->bq,
1932  p->wk, p->bk,
1933  p->wv, p->bv,
1934  p->q, p->k, p->v,
1935  p->tokens,
1936  p->tokens,
1937  aligned_D,
1938  p->num_heads,
1939  p->num_kv_heads,
1940  p->aligned_head_dim);
1941 
1942  if (p->rope_cos && p->rope_sin) {
1943  rope_forward_qk(p->q,
1944  p->k,
1945  p->rope_cos,
1946  p->rope_sin,
1947  p->num_heads,
1948  p->num_kv_heads,
1949  p->tokens,
1950  p->head_dim,
1951  p->aligned_head_dim,
1952  p->rope_pos_offset);
1953  }
1954 
1955  if (p->scores) {
1957  p->k,
1958  p->v,
1959  p->scores,
1960  p->attn_out,
1961  p->num_heads,
1962  p->num_kv_heads,
1963  p->tokens,
1964  p->head_dim,
1965  p->aligned_head_dim,
1967  } else {
1969  p->k,
1970  p->v,
1971  p->attn_out,
1972  p->num_heads,
1973  p->num_kv_heads,
1974  p->tokens,
1975  p->head_dim,
1976  p->aligned_head_dim);
1977  }
1978 
1980  p->wo,
1981  p->bo,
1982  p->proj_tmp,
1983  p->tokens,
1984  aligned_D,
1985  p->num_heads,
1986  p->aligned_head_dim);
1987 
1989  p->proj_tmp,
1990  p->residual1,
1991  p->tokens,
1992  aligned_D);
1993 
1995  p->ln2_gamma,
1996  p->ln2_out,
1997  p->ln2_rstd,
1998  p->tokens,
1999  p->embed_dim,
2000  aligned_D,
2001  p->eps);
2002 
2004  p->w1,
2005  p->b1,
2006  p->w2,
2007  p->b2,
2008  p->fc1_out,
2009  p->swiglu_out,
2010  p->mlp_out,
2011  p->tokens,
2012  aligned_D,
2013  aligned_intermediate);
2014 
2016  p->mlp_out,
2017  p->output,
2018  p->tokens,
2019  aligned_D);
2020  return;
2021  }
2022  }
2023 
2025  p->wq, p->bq,
2026  p->wk, p->bk,
2027  p->wv, p->bv,
2028  p->q, p->k, p->v,
2029  p->tokens,
2030  p->tokens,
2031  aligned_D,
2032  p->num_heads,
2033  p->num_kv_heads,
2034  p->aligned_head_dim);
2035 
2036  if (p->rope_cos && p->rope_sin) {
2037  rope_forward_qk(p->q,
2038  p->k,
2039  p->rope_cos,
2040  p->rope_sin,
2041  p->num_heads,
2042  p->num_kv_heads,
2043  p->tokens,
2044  p->head_dim,
2045  p->aligned_head_dim,
2046  p->rope_pos_offset);
2047  }
2048 
2049  if (p->scores) {
2051  p->k,
2052  p->v,
2053  p->scores,
2054  p->attn_out,
2055  p->num_heads,
2056  p->num_kv_heads,
2057  p->tokens,
2058  p->head_dim,
2059  p->aligned_head_dim,
2061  } else {
2063  p->k,
2064  p->v,
2065  p->attn_out,
2066  p->num_heads,
2067  p->num_kv_heads,
2068  p->tokens,
2069  p->head_dim,
2070  p->aligned_head_dim);
2071  }
2072 
2074  p->wo,
2075  p->bo,
2076  p->proj_tmp,
2077  p->proj_scratch,
2078  p->tokens,
2079  p->aligned_embed_dim,
2080  p->num_heads,
2081  p->aligned_head_dim);
2082 
2084  p->proj_tmp,
2085  p->residual1,
2086  p->tokens,
2087  p->aligned_embed_dim);
2088 
2090  p->ln2_gamma,
2091  p->ln2_out,
2092  p->ln2_rstd,
2093  p->tokens,
2094  p->embed_dim,
2095  p->aligned_embed_dim,
2096  p->eps);
2097 
2099  p->w1,
2100  p->b1,
2101  p->w2,
2102  p->b2,
2103  p->fc1_out,
2104  p->swiglu_out,
2105  p->mlp_out,
2106  p->tokens,
2107  p->aligned_embed_dim,
2109 
2111  p->mlp_out,
2112  p->output,
2113  p->tokens,
2114  p->aligned_embed_dim);
2115 }
static void ck_attention_project_head_major_q4_k(const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_q4_k_q8_k(const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_attention_project_head_major_q4_k_q8_k(const float *attn_out, const void *wo, const float *bo, float *out, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_q4_k(const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_mlp_swiglu_forward_q4_k_q8_k_prefill(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)

References CKLayerForwardParamsQ4K::aligned_context_window, CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParamsQ4K::attn_out, CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_project_head_major_q4_k(), ck_attention_project_head_major_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_q8k_activations_enabled(), ck_qkv_project_head_major_q4_k(), ck_qkv_project_head_major_q4_k_q8_k(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_scratch, CKLayerForwardParamsQ4K::proj_tmp, CKLayerForwardParamsQ4K::q, QK_K, CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::scores, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::tokens, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wq, and CKLayerForwardParamsQ4K::wv.

◆ ck_layer_forward_rmsnorm_swiglu_quant()

void ck_layer_forward_rmsnorm_swiglu_quant ( const CKLayerForwardParamsQ4K p)

Definition at line 2401 of file ckernel_orchestration.c.

2402 {
2403  if (!p) {
2404  return;
2405  }
2406 
2408  p->ln1_gamma,
2409  p->ln1_out,
2410  p->ln1_rstd,
2411  p->tokens,
2412  p->embed_dim,
2413  p->aligned_embed_dim,
2414  p->eps);
2415 
2417  p->wq, p->bq, p->wq_dtype,
2418  p->wk, p->bk, p->wk_dtype,
2419  p->wv, p->bv, p->wv_dtype,
2420  p->q, p->k, p->v,
2421  p->tokens,
2422  p->tokens,
2423  p->aligned_embed_dim,
2424  p->num_heads,
2425  p->num_kv_heads,
2426  p->aligned_head_dim);
2427 
2428  if (p->rope_cos && p->rope_sin) {
2429  rope_forward_qk(p->q,
2430  p->k,
2431  p->rope_cos,
2432  p->rope_sin,
2433  p->num_heads,
2434  p->num_kv_heads,
2435  p->tokens,
2436  p->head_dim,
2437  p->aligned_head_dim,
2438  p->rope_pos_offset);
2439  }
2440 
2441  if (p->scores) {
2443  p->k,
2444  p->v,
2445  p->scores,
2446  p->attn_out,
2447  p->num_heads,
2448  p->num_kv_heads,
2449  p->tokens,
2450  p->head_dim,
2451  p->aligned_head_dim,
2453  } else {
2455  p->k,
2456  p->v,
2457  p->attn_out,
2458  p->num_heads,
2459  p->num_kv_heads,
2460  p->tokens,
2461  p->head_dim,
2462  p->aligned_head_dim);
2463  }
2464 
2466  p->wo,
2467  p->bo,
2468  p->proj_tmp,
2469  p->proj_scratch,
2470  p->tokens,
2471  p->aligned_embed_dim,
2472  p->num_heads,
2473  p->aligned_head_dim,
2474  p->wo_dtype);
2475 
2477  p->proj_tmp,
2478  p->residual1,
2479  p->tokens,
2480  p->aligned_embed_dim);
2481 
2483  p->ln2_gamma,
2484  p->ln2_out,
2485  p->ln2_rstd,
2486  p->tokens,
2487  p->embed_dim,
2488  p->aligned_embed_dim,
2489  p->eps);
2490 
2492  p->w1,
2493  p->b1,
2494  p->w1_dtype,
2495  p->w2,
2496  p->b2,
2497  p->w2_dtype,
2498  p->fc1_out,
2499  p->swiglu_out,
2500  p->mlp_out,
2501  p->tokens,
2502  p->aligned_embed_dim,
2504 
2506  p->mlp_out,
2507  p->output,
2508  p->tokens,
2509  p->aligned_embed_dim);
2510 }
static void ck_qkv_project_head_major_quant(const float *input, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_attention_project_head_major_quant(const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, CKDataType wo_dtype)

References CKLayerForwardParamsQ4K::aligned_context_window, CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParamsQ4K::attn_out, CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_project_head_major_quant(), ck_mlp_swiglu_forward_quant(), ck_qkv_project_head_major_quant(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_scratch, CKLayerForwardParamsQ4K::proj_tmp, CKLayerForwardParamsQ4K::q, CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::scores, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::tokens, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w1_dtype, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::w2_dtype, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wk_dtype, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wo_dtype, CKLayerForwardParamsQ4K::wq, CKLayerForwardParamsQ4K::wq_dtype, CKLayerForwardParamsQ4K::wv, and CKLayerForwardParamsQ4K::wv_dtype.

◆ ck_layer_forward_rmsnorm_swiglu_ref()

void ck_layer_forward_rmsnorm_swiglu_ref ( const CKLayerForwardParams p)

Definition at line 1104 of file ckernel_orchestration.c.

1105 {
1106  if (!p) {
1107  return;
1108  }
1109 
1111  p->ln1_gamma,
1112  p->ln1_out,
1113  p->ln1_rstd,
1114  p->tokens,
1115  p->embed_dim,
1116  p->aligned_embed_dim,
1117  p->eps);
1118 
1120  p->wq, p->bq,
1121  p->wk, p->bk,
1122  p->wv, p->bv,
1123  p->q, p->k, p->v,
1124  p->tokens,
1125  p->tokens,
1126  p->aligned_embed_dim,
1127  p->num_heads,
1128  p->num_kv_heads,
1129  p->aligned_head_dim);
1130 
1131  if (p->rope_cos && p->rope_sin) {
1132  rope_forward_qk(p->q,
1133  p->k,
1134  p->rope_cos,
1135  p->rope_sin,
1136  p->num_heads,
1137  p->num_kv_heads,
1138  p->tokens,
1139  p->head_dim,
1140  p->aligned_head_dim,
1141  p->rope_pos_offset);
1142  }
1143 
1144  if (p->scores) {
1146  p->k,
1147  p->v,
1148  p->scores,
1149  p->attn_out,
1150  p->num_heads,
1151  p->num_kv_heads,
1152  p->tokens,
1153  p->head_dim,
1154  p->aligned_head_dim,
1156  } else {
1158  p->k,
1159  p->v,
1160  p->attn_out,
1161  p->num_heads,
1162  p->num_kv_heads,
1163  p->tokens,
1164  p->head_dim,
1165  p->aligned_head_dim);
1166  }
1167 
1169  p->wo,
1170  p->bo,
1171  p->proj_tmp,
1172  p->proj_scratch,
1173  p->tokens,
1174  p->aligned_embed_dim,
1175  p->num_heads,
1176  p->aligned_head_dim);
1177 
1179  p->proj_tmp,
1180  p->residual1,
1181  p->tokens,
1182  p->aligned_embed_dim);
1183 
1185  p->ln2_gamma,
1186  p->ln2_out,
1187  p->ln2_rstd,
1188  p->tokens,
1189  p->embed_dim,
1190  p->aligned_embed_dim,
1191  p->eps);
1192 
1194  p->w1,
1195  p->b1,
1196  p->w2,
1197  p->b2,
1198  p->fc1_out,
1199  p->swiglu_out,
1200  p->mlp_out,
1201  p->tokens,
1202  p->aligned_embed_dim,
1204 
1206  p->mlp_out,
1207  p->output,
1208  p->tokens,
1209  p->aligned_embed_dim);
1210 }
static void ck_attention_project_head_major_ref(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_ref(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_mlp_swiglu_forward_ref(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)

References CKLayerForwardParams::aligned_context_window, CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParams::attn_out, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_project_head_major_ref(), ck_mlp_swiglu_forward_ref(), ck_qkv_project_head_major_ref(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_scratch, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::q, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::scores, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::tokens, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.

◆ ck_mlp_swiglu_forward()

void ck_mlp_swiglu_forward ( const float *  input,
const float *  w1,
const float *  b1,
const float *  w2,
const float *  b2,
float *  fc1_out,
float *  swiglu_out,
float *  output,
int  tokens,
int  aligned_embed_dim,
int  aligned_intermediate_dim 
)

Definition at line 952 of file ckernel_orchestration.c.

/**
 * SwiGLU MLP forward pass over a batch of tokens (blocked serial GEMMs).
 *
 * fc1_out = input @ w1^T + b1   (w1 packs gate and up halves: 2*I rows)
 * swiglu_out = SwiGLU(fc1_out)  (I values per token)
 * output = swiglu_out @ w2^T + b2
 *
 * @param fc1_out    Scratch, tokens x (2 * aligned_intermediate_dim).
 * @param swiglu_out Scratch, tokens x aligned_intermediate_dim.
 */
void ck_mlp_swiglu_forward(const float *input,
                           const float *w1,
                           const float *b1,
                           const float *w2,
                           const float *b2,
                           float *fc1_out,
                           float *swiglu_out,
                           float *output,
                           int tokens,
                           int aligned_embed_dim,
                           int aligned_intermediate_dim)
{
    int up_dim = 2 * aligned_intermediate_dim; /* gate + up projections */
    gemm_blocked_serial(input, w1, b1, fc1_out,
                        tokens, up_dim, aligned_embed_dim);

    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);

    gemm_blocked_serial(swiglu_out, w2, b2, output,
                        tokens, aligned_embed_dim, aligned_intermediate_dim);
}
void swiglu_forward(const float *input, float *output, int tokens, int dim)

References gemm_blocked_serial(), and swiglu_forward().

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_mlp_swiglu_forward_fully_fused_token()

void ck_mlp_swiglu_forward_fully_fused_token ( const float *  input_row,
const float *  w1,
const float *  b1,
const float *  w2,
const float *  b2,
float *  output_row,
int  aligned_embed_dim,
int  aligned_intermediate_dim 
)

Definition at line 1247 of file ckernel_orchestration.c.

/**
 * Fully-fused single-token SwiGLU MLP: gate, up, activation, and down
 * projection in one kernel call, avoiding the DRAM round-trip for the
 * intermediate swiglu activations.
 *
 * @param input_row Single token, aligned_embed_dim floats.
 * @param w1        Packed [2 * aligned_intermediate_dim, aligned_embed_dim]:
 *                  first half W_gate, second half W_up.
 * @param b1        Optional packed gate+up biases (gate first), or NULL.
 * @param w2        W_down, [aligned_embed_dim, aligned_intermediate_dim].
 * @param b2        Optional down bias, or NULL.
 */
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row,
                                             const float *w1,
                                             const float *b1,
                                             const float *w2,
                                             const float *b2,
                                             float *output_row,
                                             int aligned_embed_dim,
                                             int aligned_intermediate_dim)
{
    if (!input_row || !w1 || !w2 || !output_row) {
        return;
    }

    /* Split the packed w1 into its gate and up halves. */
    const float *w_gate = w1;
    const float *w_up = w1 + (size_t)aligned_intermediate_dim * (size_t)aligned_embed_dim;

    /* b1 (when present) packs the gate bias followed by the up bias. */
    const float *b_gate = b1;
    const float *b_up = b1 ? (b1 + aligned_intermediate_dim) : NULL;

    const float *w_down = w2;
    const float *b_down = b2;

    /* Aligned dims are used throughout since weights carry alignment padding. */
    fused_mlp_swiglu_decode_v2(input_row,
                               w_gate,
                               w_up,
                               w_down,
                               b_gate,
                               b_up,
                               b_down,
                               output_row,
                               aligned_embed_dim,
                               aligned_intermediate_dim);
}
void fused_mlp_swiglu_decode_v2(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)

References fused_mlp_swiglu_decode_v2().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_mlp_swiglu_forward_fused_token()

void ck_mlp_swiglu_forward_fused_token ( const float *  input_row,
const float *  w1,
const float *  b1,
const float *  w2,
const float *  b2,
float *  swiglu_row,
float *  output_row,
int  aligned_embed_dim,
int  aligned_intermediate_dim 
)

Definition at line 1212 of file ckernel_orchestration.c.

/**
 * Partially-fused single-token SwiGLU MLP: the gate/up GEMM and the SwiGLU
 * activation are fused (gemm_swiglu_fused); the down projection is a separate
 * blocked GEMM.
 *
 * @param w1         Packed [2 * aligned_intermediate_dim, aligned_embed_dim]:
 *                   first half W_gate, second half W_up.
 * @param b1         Optional packed gate+up biases (gate first), or NULL.
 * @param swiglu_row Scratch, aligned_intermediate_dim floats.
 */
void ck_mlp_swiglu_forward_fused_token(const float *input_row,
                                       const float *w1,
                                       const float *b1,
                                       const float *w2,
                                       const float *b2,
                                       float *swiglu_row,
                                       float *output_row,
                                       int aligned_embed_dim,
                                       int aligned_intermediate_dim)
{
    if (!input_row || !w1 || !w2 || !swiglu_row || !output_row) {
        return;
    }

    const float *w_gate = w1;
    const float *w_up = w1 + (size_t)aligned_intermediate_dim * (size_t)aligned_embed_dim;
    const float *b_gate = b1;
    const float *b_up = b1 ? (b1 + aligned_intermediate_dim) : NULL;

    gemm_swiglu_fused(input_row,
                      w_gate,
                      w_up,
                      b_gate,
                      b_up,
                      swiglu_row,
                      /*M=*/1,
                      /*N=*/aligned_intermediate_dim,
                      /*K=*/aligned_embed_dim);

    gemm_blocked_serial(swiglu_row, w2, b2, output_row,
                        /*M=*/1,
                        /*N=*/aligned_embed_dim,
                        /*K=*/aligned_intermediate_dim);
}
void gemm_swiglu_fused(const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)

References gemm_blocked_serial(), and gemm_swiglu_fused().

◆ ck_mlp_swiglu_forward_q4_k()

static void ck_mlp_swiglu_forward_q4_k ( const float *  input,
const void *  w1,
const float *  b1,
const void *  w2,
const float *  b2,
float *  fc1_out,
float *  swiglu_out,
float *  output,
int  tokens,
int  aligned_embed_dim,
int  aligned_intermediate_dim 
)
static

Definition at line 589 of file ckernel_orchestration.c.

/**
 * SwiGLU MLP forward pass with Q4_K-quantized w1/w2 and fp32 activations.
 * Same dataflow as ck_mlp_swiglu_forward, but both GEMMs go through the
 * Q4_K kernel.
 */
static void ck_mlp_swiglu_forward_q4_k(const float *input,
                                       const void *w1,
                                       const float *b1,
                                       const void *w2,
                                       const float *b2,
                                       float *fc1_out,
                                       float *swiglu_out,
                                       float *output,
                                       int tokens,
                                       int aligned_embed_dim,
                                       int aligned_intermediate_dim)
{
    int up_dim = 2 * aligned_intermediate_dim; /* gate + up projections */
    gemm_nt_q4_k(input, w1, b1, fc1_out,
                 tokens, up_dim, aligned_embed_dim);

    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);

    gemm_nt_q4_k(swiglu_out, w2, b2, output,
                 tokens, aligned_embed_dim, aligned_intermediate_dim);
}

References gemm_nt_q4_k(), and swiglu_forward().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), and ck_layer_forward_rmsnorm_swiglu_q4_k().

◆ ck_mlp_swiglu_forward_q4_k_q8_k()

static void ck_mlp_swiglu_forward_q4_k_q8_k ( const float *  input,
const void *  w1,
const float *  b1,
const void *  w2,
const float *  b2,
float *  fc1_out,
float *  swiglu_out,
float *  output,
int  aligned_embed_dim,
int  aligned_intermediate_dim 
)
static

Definition at line 635 of file ckernel_orchestration.c.

645 {
646  if (!input || !w1 || !w2 || !fc1_out || !swiglu_out || !output) {
647  return;
648  }
649  if ((aligned_embed_dim % QK_K) != 0 || (aligned_intermediate_dim % QK_K) != 0) {
650  return;
651  }
652 
653  const int up_dim = 2 * aligned_intermediate_dim;
654  const int q8_blocks_embed = aligned_embed_dim / QK_K;
655  const int q8_blocks_inter = aligned_intermediate_dim / QK_K;
656  const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
657  block_q8_K q8_buf[q8_blocks_max];
658 
659  quantize_row_q8_k(input, q8_buf, aligned_embed_dim);
660  gemm_nt_q4_k_q8_k(q8_buf, w1, b1, fc1_out,
661  /*M=*/1, /*N=*/up_dim, /*K=*/aligned_embed_dim);
662 
663  swiglu_forward(fc1_out, swiglu_out, /*tokens=*/1, aligned_intermediate_dim);
664 
665  quantize_row_q8_k(swiglu_out, q8_buf, aligned_intermediate_dim);
666  gemm_nt_q4_k_q8_k(q8_buf, w2, b2, output,
667  /*M=*/1, /*N=*/aligned_embed_dim, /*K=*/aligned_intermediate_dim);
668 }

References gemm_nt_q4_k_q8_k(), QK_K, quantize_row_q8_k(), and swiglu_forward().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_q4_k().

◆ ck_mlp_swiglu_forward_q4_k_q8_k_prefill()

static void ck_mlp_swiglu_forward_q4_k_q8_k_prefill ( const float *  input,
const void *  w1,
const float *  b1,
const void *  w2,
const float *  b2,
float *  fc1_out,
float *  swiglu_out,
float *  output,
int  tokens,
int  aligned_embed_dim,
int  aligned_intermediate_dim 
)
static

Definition at line 1808 of file ckernel_orchestration.c.

1819 {
1820  if (!input || !w1 || !w2 || !fc1_out || !swiglu_out || !output) {
1821  return;
1822  }
1823  if (tokens <= 0) {
1824  return;
1825  }
1826  if ((aligned_embed_dim % QK_K) != 0 || (aligned_intermediate_dim % QK_K) != 0) {
1827  return;
1828  }
1829 
1830  const int up_dim = 2 * aligned_intermediate_dim;
1831  const int q8_blocks_embed = aligned_embed_dim / QK_K;
1832  const int q8_blocks_inter = aligned_intermediate_dim / QK_K;
1833  const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
1834  block_q8_K q8_buf[q8_blocks_max];
1835 
1836  for (int t = 0; t < tokens; ++t) {
1837  const float *input_row = input + (size_t)t * (size_t)aligned_embed_dim;
1838  float *fc1_row = fc1_out + (size_t)t * (size_t)up_dim;
1839 
1840  quantize_row_q8_k(input_row, q8_buf, aligned_embed_dim);
1841  gemm_nt_q4_k_q8_k(q8_buf, w1, b1, fc1_row,
1842  /*M=*/1, /*N=*/up_dim, /*K=*/aligned_embed_dim);
1843  }
1844 
1845  swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
1846 
1847  for (int t = 0; t < tokens; ++t) {
1848  const float *swiglu_row = swiglu_out + (size_t)t * (size_t)aligned_intermediate_dim;
1849  float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
1850 
1851  quantize_row_q8_k(swiglu_row, q8_buf, aligned_intermediate_dim);
1852  gemm_nt_q4_k_q8_k(q8_buf, w2, b2, out_row,
1853  /*M=*/1, /*N=*/aligned_embed_dim, /*K=*/aligned_intermediate_dim);
1854  }
1855 }

References gemm_nt_q4_k_q8_k(), QK_K, quantize_row_q8_k(), and swiglu_forward().

Referenced by ck_layer_forward_rmsnorm_swiglu_q4_k().

◆ ck_mlp_swiglu_forward_quant()

/**
 * Dtype-dispatched SwiGLU MLP forward (prefill path).
 *
 * input -> ck_gemm_nt_quant(w1) -> fc1_out -> swiglu -> swiglu_out
 *       -> ck_gemm_nt_quant(w2) -> output
 *
 * @param fc1_out     scratch, [tokens, 2*aligned_intermediate_dim]
 * @param swiglu_out  scratch, [tokens, aligned_intermediate_dim]
 * @param output      [tokens, aligned_embed_dim]
 */
static void ck_mlp_swiglu_forward_quant(const float *input,
                                        const void *w1,
                                        const float *b1,
                                        CKDataType w1_dtype,
                                        const void *w2,
                                        const float *b2,
                                        CKDataType w2_dtype,
                                        float *fc1_out,
                                        float *swiglu_out,
                                        float *output,
                                        int tokens,
                                        int aligned_embed_dim,
                                        int aligned_intermediate_dim)
{
    /* Guard null buffers and degenerate sizes, matching the checks done by
     * the q4_k_q8_k variant of this helper. */
    if (!input || !w1 || !w2 || !fc1_out || !swiglu_out || !output) {
        return;
    }
    if (tokens <= 0) {
        return;
    }

    const int up_dim = 2 * aligned_intermediate_dim;
    ck_gemm_nt_quant(input, w1, b1, fc1_out,
                     tokens, up_dim, aligned_embed_dim, w1_dtype);

    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);

    ck_gemm_nt_quant(swiglu_out, w2, b2, output,
                     tokens, aligned_embed_dim, aligned_intermediate_dim, w2_dtype);
}

References ck_gemm_nt_quant(), and swiglu_forward().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_quant(), and ck_layer_forward_rmsnorm_swiglu_quant().

◆ ck_mlp_swiglu_forward_ref()

/**
 * Reference (naive parallel GEMM) SwiGLU MLP forward.
 *
 * Same dataflow as ck_mlp_swiglu_forward_quant but with FP32 weights and
 * gemm_naive_parallel as the matmul backend.
 *
 * @param fc1_out     scratch, [tokens, 2*aligned_intermediate_dim]
 * @param swiglu_out  scratch, [tokens, aligned_intermediate_dim]
 * @param output      [tokens, aligned_embed_dim]
 */
static void ck_mlp_swiglu_forward_ref(const float *input,
                                      const float *w1,
                                      const float *b1,
                                      const float *w2,
                                      const float *b2,
                                      float *fc1_out,
                                      float *swiglu_out,
                                      float *output,
                                      int tokens,
                                      int aligned_embed_dim,
                                      int aligned_intermediate_dim)
{
    /* Guard null buffers and degenerate sizes for consistency with the
     * quantized variants of this helper. */
    if (!input || !w1 || !w2 || !fc1_out || !swiglu_out || !output) {
        return;
    }
    if (tokens <= 0) {
        return;
    }

    const int up_dim = 2 * aligned_intermediate_dim;
    gemm_naive_parallel(input, w1, b1, fc1_out,
                        tokens, up_dim, aligned_embed_dim);

    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);

    gemm_naive_parallel(swiglu_out, w2, b2, output,
                        tokens, aligned_embed_dim, aligned_intermediate_dim);
}

References gemm_naive_parallel(), and swiglu_forward().

Referenced by ck_layer_forward_rmsnorm_swiglu_ref().

◆ ck_q8k_activations_enabled()

/**
 * Lazily resolve whether Q8_K activation quantization is enabled.
 *
 * Precedence: the CK_Q8K_ACTIVATIONS environment variable wins when set and
 * non-empty (leading '0'/'n'/'N'/'f'/'F' disables, anything else enables);
 * otherwise the default is "on" unless strict-parity mode is active.
 * The decision is computed once and cached for the process lifetime.
 */
static int ck_q8k_activations_enabled(void)
{
    static int cached = -2; /* -2 == not yet resolved */

    if (cached == -2) {
        const char *flag = getenv("CK_Q8K_ACTIVATIONS");
        if (flag && flag[0]) {
            switch (flag[0]) {
            case '0':
            case 'n':
            case 'N':
            case 'f':
            case 'F':
                cached = 0;
                break;
            default:
                cached = 1;
                break;
            }
        } else {
            /* No explicit override: strict parity disables the fast path. */
            cached = ck_strict_parity_enabled() ? 0 : 1;
        }
    }
    return cached;
}
See also: `int ck_strict_parity_enabled(void)` — queries the strict-parity runtime flag.

References ck_strict_parity_enabled().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), and ck_layer_forward_rmsnorm_swiglu_q4_k().

◆ ck_qkv_project_head_major()

/**
 * Project token activations into head-major Q/K/V buffers (FP32 weights,
 * blocked serial GEMM backend).
 *
 * Q output is laid out [num_heads][tokens][aligned_head_dim]; K/V outputs use
 * kv_stride_tokens as the per-head token stride so they can write directly
 * into a larger cache. Requires kv_stride_tokens >= tokens.
 */
void ck_qkv_project_head_major(const float *input,
                               const float *wq,
                               const float *bq,
                               const float *wk,
                               const float *bk,
                               const float *wv,
                               const float *bv,
                               float *q,
                               float *k,
                               float *v,
                               int tokens,
                               int kv_stride_tokens,
                               int aligned_embed_dim,
                               int num_heads,
                               int num_kv_heads,
                               int aligned_head_dim)
{
    if (!input || !wq || !wk || !wv || !q || !k || !v) {
        return;
    }
    if (kv_stride_tokens < tokens) {
        return; /* K/V destination cannot hold all projected tokens */
    }

    const size_t w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t kv_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;

    for (int h = 0; h < num_heads; ++h) {
        const float *bias = bq ? bq + (size_t)h * (size_t)aligned_head_dim : NULL;
        gemm_blocked_serial(input,
                            wq + (size_t)h * w_stride,
                            bias,
                            q + (size_t)h * q_stride,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }

    for (int h = 0; h < num_kv_heads; ++h) {
        const size_t w_off = (size_t)h * w_stride;
        const size_t b_off = (size_t)h * (size_t)aligned_head_dim;

        gemm_blocked_serial(input,
                            wk + w_off,
                            bk ? bk + b_off : NULL,
                            k + (size_t)h * kv_stride,
                            tokens, aligned_head_dim, aligned_embed_dim);
        gemm_blocked_serial(input,
                            wv + w_off,
                            bv ? bv + b_off : NULL,
                            v + (size_t)h * kv_stride,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }
}

References gemm_blocked_serial().

Referenced by ck_layer_forward_rmsnorm_swiglu().

◆ ck_qkv_project_head_major_backward()

/**
 * Backward pass of the head-major Q/K/V projection.
 *
 * For every head, fc2_backward_kernel computes the weight/bias gradients and
 * the per-head input gradient into `scratch`; the input gradients of all
 * heads are then accumulated into d_input via ck_add_inplace.
 *
 * @param scratch  caller-provided buffer, [tokens, aligned_embed_dim]
 *
 * Gradient layouts mirror the forward pass: d_q/d_k/d_v are head-major with
 * per-head stride tokens*aligned_head_dim; weight gradients use per-head
 * stride aligned_head_dim*aligned_embed_dim.
 */
void ck_qkv_project_head_major_backward(const float *d_q,
                                        const float *d_k,
                                        const float *d_v,
                                        const float *input,
                                        const float *wq,
                                        const float *bq,
                                        const float *wk,
                                        const float *bk,
                                        const float *wv,
                                        const float *bv,
                                        float *d_input,
                                        float *d_wq,
                                        float *d_bq,
                                        float *d_wk,
                                        float *d_bk,
                                        float *d_wv,
                                        float *d_bv,
                                        float *scratch,
                                        int tokens,
                                        int aligned_embed_dim,
                                        int num_heads,
                                        int num_kv_heads,
                                        int aligned_head_dim,
                                        int num_threads)
{
    if (!d_q || !d_k || !d_v || !input || !wq || !wk || !wv ||
        !d_input || !d_wq || !d_bq || !d_wk || !d_bk || !d_wv || !d_bv || !scratch) {
        return;
    }
    /* A negative tokens count would wrap to a huge size_t below. */
    if (tokens <= 0 || aligned_embed_dim <= 0) {
        return;
    }

    const size_t total_in = (size_t)tokens * (size_t)aligned_embed_dim;
    /* All-bits-zero is 0.0f for IEEE-754 floats, so memset is a safe and
     * faster replacement for the previous element-wise zeroing loop. */
    memset(d_input, 0, total_in * sizeof(float));

    const size_t head_weight_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t head_out_stride = (size_t)tokens * (size_t)aligned_head_dim;

    for (int h = 0; h < num_heads; ++h) {
        fc2_backward_kernel(d_q + (size_t)h * head_out_stride,
                            input,
                            wq + (size_t)h * head_weight_stride,
                            scratch,
                            d_wq + (size_t)h * head_weight_stride,
                            d_bq + (size_t)h * (size_t)aligned_head_dim,
                            tokens,
                            aligned_embed_dim,
                            aligned_head_dim,
                            num_threads);
        ck_add_inplace(d_input, scratch, tokens, aligned_embed_dim);
    }

    for (int h = 0; h < num_kv_heads; ++h) {
        const size_t w_off = (size_t)h * head_weight_stride;
        const size_t b_off = (size_t)h * (size_t)aligned_head_dim;
        const size_t o_off = (size_t)h * head_out_stride;

        fc2_backward_kernel(d_k + o_off,
                            input,
                            wk + w_off,
                            scratch,
                            d_wk + w_off,
                            d_bk + b_off,
                            tokens,
                            aligned_embed_dim,
                            aligned_head_dim,
                            num_threads);
        ck_add_inplace(d_input, scratch, tokens, aligned_embed_dim);

        fc2_backward_kernel(d_v + o_off,
                            input,
                            wv + w_off,
                            scratch,
                            d_wv + w_off,
                            d_bv + b_off,
                            tokens,
                            aligned_embed_dim,
                            aligned_head_dim,
                            num_threads);
        ck_add_inplace(d_input, scratch, tokens, aligned_embed_dim);
    }
}

References ck_add_inplace(), and fc2_backward_kernel().

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ ck_qkv_project_head_major_q4_k()

static void ck_qkv_project_head_major_q4_k ( const float *  input,
const void *  wq,
const float *  bq,
const void *  wk,
const float *  bk,
const void *  wv,
const float *  bv,
float *  q,
float *  k,
float *  v,
int  tokens,
int  kv_stride_tokens,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  aligned_head_dim 
)
static

Definition at line 387 of file ckernel_orchestration.c.

398 {
399  if (!input || !wq || !wk || !wv || !q || !k || !v) {
400  return;
401  }
402  if (kv_stride_tokens < tokens) {
403  return;
404  }
405 
406  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
407  const size_t head_w_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
408  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
409  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
410 
411  const uint8_t *wq_bytes = (const uint8_t *)wq;
412  const uint8_t *wk_bytes = (const uint8_t *)wk;
413  const uint8_t *wv_bytes = (const uint8_t *)wv;
414 
415  for (int h = 0; h < num_heads; ++h) {
416  const void *wq_h = wq_bytes + (size_t)h * head_w_bytes;
417  const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
418  float *q_h = q + (size_t)h * q_head_stride;
419 
420  gemm_nt_q4_k(input, wq_h, bq_h, q_h,
421  tokens, aligned_head_dim, aligned_embed_dim);
422  }
423 
424  for (int h = 0; h < num_kv_heads; ++h) {
425  const void *wk_h = wk_bytes + (size_t)h * head_w_bytes;
426  const void *wv_h = wv_bytes + (size_t)h * head_w_bytes;
427 
428  const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
429  const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
430 
431  float *k_h = k + (size_t)h * kv_head_stride;
432  float *v_h = v + (size_t)h * kv_head_stride;
433 
434  gemm_nt_q4_k(input, wk_h, bk_h, k_h,
435  tokens, aligned_head_dim, aligned_embed_dim);
436  gemm_nt_q4_k(input, wv_h, bv_h, v_h,
437  tokens, aligned_head_dim, aligned_embed_dim);
438  }
439 }
See also: `static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)` — calculates the total bytes for n_elements of the given dtype.

References CK_DT_Q4_K, ck_dtype_row_bytes(), and gemm_nt_q4_k().

Referenced by ck_layer_forward_rmsnorm_swiglu_q4_k().

◆ ck_qkv_project_head_major_q4_k_q8_k()

static void ck_qkv_project_head_major_q4_k_q8_k ( const float *  input,
const void *  wq,
const float *  bq,
const void *  wk,
const float *  bk,
const void *  wv,
const float *  bv,
float *  q,
float *  k,
float *  v,
int  tokens,
int  kv_stride_tokens,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  aligned_head_dim 
)
static

Definition at line 1694 of file ckernel_orchestration.c.

1705 {
1706  if (!input || !wq || !wk || !wv || !q || !k || !v) {
1707  return;
1708  }
1709  if (tokens <= 0 || aligned_embed_dim <= 0) {
1710  return;
1711  }
1712  if (kv_stride_tokens < tokens) {
1713  return;
1714  }
1715  if ((aligned_embed_dim % QK_K) != 0) {
1716  return;
1717  }
1718 
1719  const int q8_blocks = aligned_embed_dim / QK_K;
1720  block_q8_K q8_buf[q8_blocks];
1721  const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
1722  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
1723 
1724  float q_token[num_heads * aligned_head_dim];
1725  float k_token[num_kv_heads * aligned_head_dim];
1726  float v_token[num_kv_heads * aligned_head_dim];
1727 
1728  for (int t = 0; t < tokens; ++t) {
1729  const float *input_row = input + (size_t)t * (size_t)aligned_embed_dim;
1730  quantize_row_q8_k(input_row, q8_buf, aligned_embed_dim);
1731 
1733  wq, bq,
1734  wk, bk,
1735  wv, bv,
1736  q_token,
1737  k_token,
1738  v_token,
1739  aligned_embed_dim,
1740  num_heads,
1741  num_kv_heads,
1742  aligned_head_dim);
1743 
1744  for (int h = 0; h < num_heads; ++h) {
1745  float *q_dst = q + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
1746  memcpy(q_dst,
1747  q_token + (size_t)h * (size_t)aligned_head_dim,
1748  (size_t)aligned_head_dim * sizeof(float));
1749  }
1750 
1751  for (int h = 0; h < num_kv_heads; ++h) {
1752  float *k_dst = k + (size_t)h * kv_head_stride + (size_t)t * (size_t)aligned_head_dim;
1753  float *v_dst = v + (size_t)h * kv_head_stride + (size_t)t * (size_t)aligned_head_dim;
1754  memcpy(k_dst,
1755  k_token + (size_t)h * (size_t)aligned_head_dim,
1756  (size_t)aligned_head_dim * sizeof(float));
1757  memcpy(v_dst,
1758  v_token + (size_t)h * (size_t)aligned_head_dim,
1759  (size_t)aligned_head_dim * sizeof(float));
1760  }
1761  }
1762 }

References ck_qkv_project_head_major_token_q4_k_q8_k(), QK_K, and quantize_row_q8_k().

Referenced by ck_layer_forward_rmsnorm_swiglu_q4_k().

◆ ck_qkv_project_head_major_quant()

/**
 * Head-major Q/K/V projection with per-tensor dtype dispatch.
 *
 * Each weight tensor may independently be FP32 (indexed in float elements)
 * or a quantized format (indexed in packed bytes via ck_dtype_row_bytes).
 * Output layout matches ck_qkv_project_head_major.
 */
static void ck_qkv_project_head_major_quant(const float *input,
                                            const void *wq,
                                            const float *bq,
                                            CKDataType wq_dtype,
                                            const void *wk,
                                            const float *bk,
                                            CKDataType wk_dtype,
                                            const void *wv,
                                            const float *bv,
                                            CKDataType wv_dtype,
                                            float *q,
                                            float *k,
                                            float *v,
                                            int tokens,
                                            int kv_stride_tokens,
                                            int aligned_embed_dim,
                                            int num_heads,
                                            int num_kv_heads,
                                            int aligned_head_dim)
{
    if (!input || !wq || !wk || !wv || !q || !k || !v) {
        return;
    }
    if (kv_stride_tokens < tokens) {
        return;
    }

    const size_t w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t kv_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;

    const size_t wq_bytes_per_head = ck_dtype_row_bytes(wq_dtype, w_elems);
    const size_t wk_bytes_per_head = ck_dtype_row_bytes(wk_dtype, w_elems);
    const size_t wv_bytes_per_head = ck_dtype_row_bytes(wv_dtype, w_elems);

    for (int h = 0; h < num_heads; ++h) {
        /* FP32 weights are addressed in float elements; packed quantized
         * formats are addressed in raw bytes. */
        const void *w_h;
        if (wq_dtype == CK_DT_FP32) {
            w_h = (const float *)wq + (size_t)h * w_elems;
        } else {
            w_h = (const uint8_t *)wq + (size_t)h * wq_bytes_per_head;
        }
        const float *bias = bq ? bq + (size_t)h * (size_t)aligned_head_dim : NULL;

        ck_gemm_nt_quant(input, w_h, bias,
                         q + (size_t)h * q_stride,
                         tokens, aligned_head_dim, aligned_embed_dim, wq_dtype);
    }

    for (int h = 0; h < num_kv_heads; ++h) {
        const size_t b_off = (size_t)h * (size_t)aligned_head_dim;

        const void *wk_h;
        if (wk_dtype == CK_DT_FP32) {
            wk_h = (const float *)wk + (size_t)h * w_elems;
        } else {
            wk_h = (const uint8_t *)wk + (size_t)h * wk_bytes_per_head;
        }
        const void *wv_h;
        if (wv_dtype == CK_DT_FP32) {
            wv_h = (const float *)wv + (size_t)h * w_elems;
        } else {
            wv_h = (const uint8_t *)wv + (size_t)h * wv_bytes_per_head;
        }

        ck_gemm_nt_quant(input, wk_h, bk ? bk + b_off : NULL,
                         k + (size_t)h * kv_stride,
                         tokens, aligned_head_dim, aligned_embed_dim, wk_dtype);
        ck_gemm_nt_quant(input, wv_h, bv ? bv + b_off : NULL,
                         v + (size_t)h * kv_stride,
                         tokens, aligned_head_dim, aligned_embed_dim, wv_dtype);
    }
}

References CK_DT_FP32, ck_dtype_row_bytes(), and ck_gemm_nt_quant().

Referenced by ck_layer_forward_rmsnorm_swiglu_quant().

◆ ck_qkv_project_head_major_ref()

/**
 * Reference head-major Q/K/V projection using the naive parallel GEMM.
 *
 * Numerically mirrors ck_qkv_project_head_major but with
 * gemm_naive_parallel as the matmul backend; used for parity checks.
 */
static void ck_qkv_project_head_major_ref(const float *input,
                                          const float *wq,
                                          const float *bq,
                                          const float *wk,
                                          const float *bk,
                                          const float *wv,
                                          const float *bv,
                                          float *q,
                                          float *k,
                                          float *v,
                                          int tokens,
                                          int kv_stride_tokens,
                                          int aligned_embed_dim,
                                          int num_heads,
                                          int num_kv_heads,
                                          int aligned_head_dim)
{
    if (!input || !wq || !wk || !wv || !q || !k || !v) {
        return;
    }
    if (kv_stride_tokens < tokens) {
        return;
    }

    const size_t w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t kv_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;

    for (int h = 0; h < num_heads; ++h) {
        const float *bias = bq ? bq + (size_t)h * (size_t)aligned_head_dim : NULL;
        gemm_naive_parallel(input,
                            wq + (size_t)h * w_stride,
                            bias,
                            q + (size_t)h * q_stride,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }

    for (int h = 0; h < num_kv_heads; ++h) {
        const size_t w_off = (size_t)h * w_stride;
        const size_t b_off = (size_t)h * (size_t)aligned_head_dim;

        gemm_naive_parallel(input,
                            wk + w_off,
                            bk ? bk + b_off : NULL,
                            k + (size_t)h * kv_stride,
                            tokens, aligned_head_dim, aligned_embed_dim);
        gemm_naive_parallel(input,
                            wv + w_off,
                            bv ? bv + b_off : NULL,
                            v + (size_t)h * kv_stride,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }
}

References gemm_naive_parallel().

Referenced by ck_layer_forward_rmsnorm_swiglu_ref().

◆ ck_qkv_project_head_major_token_q4_k()

static void ck_qkv_project_head_major_token_q4_k ( const float *  input_row,
const void *  wq,
const float *  bq,
const void *  wk,
const float *  bk,
const void *  wv,
const float *  bv,
float *  q_token,
float *  k_token,
float *  v_token,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  aligned_head_dim 
)
static

Definition at line 1604 of file ckernel_orchestration.c.

1615 {
1616  if (!input_row || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
1617  return;
1618  }
1619 
1620  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1621  const size_t head_w_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1622 
1623  const uint8_t *wq_bytes = (const uint8_t *)wq;
1624  const uint8_t *wk_bytes = (const uint8_t *)wk;
1625  const uint8_t *wv_bytes = (const uint8_t *)wv;
1626 
1627  for (int h = 0; h < num_heads; ++h) {
1628  const void *wq_h = wq_bytes + (size_t)h * head_w_bytes;
1629  const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
1630  float *q_h = q_token + (size_t)h * (size_t)aligned_head_dim;
1631  gemm_nt_q4_k(input_row, wq_h, bq_h, q_h,
1632  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1633  }
1634 
1635  for (int h = 0; h < num_kv_heads; ++h) {
1636  const void *wk_h = wk_bytes + (size_t)h * head_w_bytes;
1637  const void *wv_h = wv_bytes + (size_t)h * head_w_bytes;
1638  const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
1639  const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
1640  float *k_h = k_token + (size_t)h * (size_t)aligned_head_dim;
1641  float *v_h = v_token + (size_t)h * (size_t)aligned_head_dim;
1642  gemm_nt_q4_k(input_row, wk_h, bk_h, k_h,
1643  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1644  gemm_nt_q4_k(input_row, wv_h, bv_h, v_h,
1645  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1646  }
1647 }

References CK_DT_Q4_K, ck_dtype_row_bytes(), and gemm_nt_q4_k().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_q4_k().

◆ ck_qkv_project_head_major_token_q4_k_q8_k()

static void ck_qkv_project_head_major_token_q4_k_q8_k ( const block_q8_K input_q8,
const void *  wq,
const float *  bq,
const void *  wk,
const float *  bk,
const void *  wv,
const float *  bv,
float *  q_token,
float *  k_token,
float *  v_token,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  aligned_head_dim 
)
static

Definition at line 1649 of file ckernel_orchestration.c.

1660 {
1661  if (!input_q8 || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
1662  return;
1663  }
1664 
1665  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
1666  const size_t head_w_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, head_w_elems);
1667 
1668  const uint8_t *wq_bytes = (const uint8_t *)wq;
1669  const uint8_t *wk_bytes = (const uint8_t *)wk;
1670  const uint8_t *wv_bytes = (const uint8_t *)wv;
1671 
1672  for (int h = 0; h < num_heads; ++h) {
1673  const void *wq_h = wq_bytes + (size_t)h * head_w_bytes;
1674  const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
1675  float *q_h = q_token + (size_t)h * (size_t)aligned_head_dim;
1676  gemm_nt_q4_k_q8_k(input_q8, wq_h, bq_h, q_h,
1677  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1678  }
1679 
1680  for (int h = 0; h < num_kv_heads; ++h) {
1681  const void *wk_h = wk_bytes + (size_t)h * head_w_bytes;
1682  const void *wv_h = wv_bytes + (size_t)h * head_w_bytes;
1683  const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
1684  const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
1685  float *k_h = k_token + (size_t)h * (size_t)aligned_head_dim;
1686  float *v_h = v_token + (size_t)h * (size_t)aligned_head_dim;
1687  gemm_nt_q4_k_q8_k(input_q8, wk_h, bk_h, k_h,
1688  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1689  gemm_nt_q4_k_q8_k(input_q8, wv_h, bv_h, v_h,
1690  /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim);
1691  }
1692 }

References CK_DT_Q4_K, ck_dtype_row_bytes(), and gemm_nt_q4_k_q8_k().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), and ck_qkv_project_head_major_q4_k_q8_k().

◆ ck_qkv_project_head_major_token_quant()

/**
 * Single-token Q/K/V projection with per-tensor dtype dispatch (decode path).
 *
 * Projects one activation row into per-head contiguous q_token / k_token /
 * v_token vectors. Each weight tensor may independently be FP32 (addressed
 * in float elements) or quantized (addressed in packed bytes).
 */
static void ck_qkv_project_head_major_token_quant(const float *input_row,
                                                  const void *wq,
                                                  const float *bq,
                                                  CKDataType wq_dtype,
                                                  const void *wk,
                                                  const float *bk,
                                                  CKDataType wk_dtype,
                                                  const void *wv,
                                                  const float *bv,
                                                  CKDataType wv_dtype,
                                                  float *q_token,
                                                  float *k_token,
                                                  float *v_token,
                                                  int aligned_embed_dim,
                                                  int num_heads,
                                                  int num_kv_heads,
                                                  int aligned_head_dim)
{
    if (!input_row || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
        return;
    }

    const size_t w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t wq_bytes_per_head = ck_dtype_row_bytes(wq_dtype, w_elems);
    const size_t wk_bytes_per_head = ck_dtype_row_bytes(wk_dtype, w_elems);
    const size_t wv_bytes_per_head = ck_dtype_row_bytes(wv_dtype, w_elems);
    const size_t head_off = (size_t)aligned_head_dim;

    for (int h = 0; h < num_heads; ++h) {
        const void *w_h;
        if (wq_dtype == CK_DT_FP32) {
            w_h = (const float *)wq + (size_t)h * w_elems;
        } else {
            w_h = (const uint8_t *)wq + (size_t)h * wq_bytes_per_head;
        }
        ck_gemm_nt_quant(input_row, w_h,
                         bq ? bq + (size_t)h * head_off : NULL,
                         q_token + (size_t)h * head_off,
                         /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim,
                         wq_dtype);
    }

    for (int h = 0; h < num_kv_heads; ++h) {
        const size_t b_off = (size_t)h * head_off;

        const void *wk_h;
        if (wk_dtype == CK_DT_FP32) {
            wk_h = (const float *)wk + (size_t)h * w_elems;
        } else {
            wk_h = (const uint8_t *)wk + (size_t)h * wk_bytes_per_head;
        }
        const void *wv_h;
        if (wv_dtype == CK_DT_FP32) {
            wv_h = (const float *)wv + (size_t)h * w_elems;
        } else {
            wv_h = (const uint8_t *)wv + (size_t)h * wv_bytes_per_head;
        }

        ck_gemm_nt_quant(input_row, wk_h,
                         bk ? bk + b_off : NULL,
                         k_token + b_off,
                         /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim,
                         wk_dtype);
        ck_gemm_nt_quant(input_row, wv_h,
                         bv ? bv + b_off : NULL,
                         v_token + b_off,
                         /*M=*/1, /*N=*/aligned_head_dim, /*K=*/aligned_embed_dim,
                         wv_dtype);
    }
}

References CK_DT_FP32, ck_dtype_row_bytes(), and ck_gemm_nt_quant().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_quant().

◆ ck_residual_add_backward()

/**
 * Backward of the residual add out = a + b.
 *
 * The residual add has a unit Jacobian w.r.t. both operands, so the upstream
 * gradient d_out is propagated unchanged into both d_a and d_b.
 *
 * @param d_out  upstream gradient, [tokens, aligned_embed_dim]
 * @param d_a    gradient w.r.t. the first operand (written)
 * @param d_b    gradient w.r.t. the second operand (written)
 */
void ck_residual_add_backward(const float *d_out,
                              float *d_a,
                              float *d_b,
                              int tokens,
                              int aligned_embed_dim)
{
    if (!d_out || !d_a || !d_b) {
        return;
    }
    /* A negative dimension would wrap to a huge size_t count below. */
    if (tokens <= 0 || aligned_embed_dim <= 0) {
        return;
    }

    const size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
    /* memmove (not memcpy) so a caller passing d_a == d_out, which the
     * original element-wise loop tolerated, remains well-defined. */
    memmove(d_a, d_out, total * sizeof(float));
    memmove(d_b, d_out, total * sizeof(float));
}

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ ck_residual_add_token_major()

void ck_residual_add_token_major ( const float *  a,
const float *  b,
float *  out,
int  tokens,
int  aligned_embed_dim 
)