← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_orchestration.h File Reference
#include <stddef.h>
#include "ckernel_dtype.h"

Go to the source code of this file.

Data Structures

struct  CKLayerBackwardParams
 
struct  CKLayerForwardParams
 
struct  CKLayerForwardParamsQ4K
 

Functions

void ck_attention_project_head_major (const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
void ck_attention_project_head_major_backward (const float *d_out, const float *attn_out, const float *wo, float *d_attn_out, float *d_wo, float *d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
void ck_attention_project_head_major_decode_token (const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
 
void ck_gemm_nt_quant (const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
 
void ck_layer_backward_rmsnorm_swiglu (const CKLayerBackwardParams *p)
 
void ck_layer_forward_rmsnorm_swiglu (const CKLayerForwardParams *p)
 
void ck_layer_forward_rmsnorm_swiglu_decode (const CKLayerForwardParams *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_decode_fused (const CKLayerForwardParams *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn (const CKLayerForwardParams *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp (const CKLayerForwardParams *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_decode_q4_k (const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_decode_quant (const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
 
void ck_layer_forward_rmsnorm_swiglu_q4_k (const CKLayerForwardParamsQ4K *p)
 
void ck_layer_forward_rmsnorm_swiglu_quant (const CKLayerForwardParamsQ4K *p)
 
void ck_layer_forward_rmsnorm_swiglu_ref (const CKLayerForwardParams *p)
 
void ck_mlp_swiglu_forward (const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
 
void ck_mlp_swiglu_forward_fully_fused_token (const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
 
void ck_mlp_swiglu_forward_fused_token (const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *swiglu_row, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
 
void ck_qkv_project_head_major (const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
void ck_qkv_project_head_major_backward (const float *d_q, const float *d_k, const float *d_v, const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *d_input, float *d_wq, float *d_bq, float *d_wk, float *d_bk, float *d_wv, float *d_bv, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)
 
void ck_qkv_project_head_major_token (const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
 
void ck_residual_add_backward (const float *d_out, float *d_a, float *d_b, int tokens, int aligned_embed_dim)
 
void ck_residual_add_token_major (const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
 

Detailed Description


LEGACY HEADER - NOT USED IN v6.6

This header declares v6.5 orchestration functions that are NO LONGER USED. v6.6 uses IR Lower 3 + codegen instead of hardcoded orchestration.

v6.6 Architecture (REPLACEMENT):

  • Kernel dispatch: version/v6.6/scripts/build_ir_v6_6.py + ckernel_codegen.c
  • Memory planning: version/v6.6/scripts/memory_planner_v6_6.py
  • Registry: version/v6.6/kernel_maps/KERNEL_REGISTRY.json
  • Kernel bindings: version/v6.6/kernel_maps/kernel_bindings.json

Deprecated functions (NOT used in v6.6):

  • ck_layer_forward_rmsnorm_swiglu* -> IR Lower 3 + mega_fused_* kernels
  • ck_qkv_project_head_major* -> q_proj/k_proj/v_proj ops in IR
  • ck_attention_project_head_major* -> out_proj op in IR
  • ck_mlp_swiglu_forward* -> mlp_gate_up/mlp_down ops in IR
  • ck_gemm_nt_quant -> KERNEL_REGISTRY.json dispatch
  • ck_residual_add_token_major -> residual_add op in IR

To remove completely:

  1. Delete this header
  2. Delete ckernel_orchestration.c
  3. Remove from Makefile SRCS list

Last used: v6.5

Deprecated: v6.6 (2026-02)

Definition in file ckernel_orchestration.h.

Function Documentation

◆ ck_attention_project_head_major()

void ck_attention_project_head_major ( const float *  attn_out,
const float *  wo,
const float *  bo,
float *  out,
float *  scratch,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)

Definition at line 730 of file ckernel_orchestration.c.

739 {
740  if (!attn_out || !wo || !out) {
741  return;
742  }
743  if (num_heads > 1 && !scratch) {
744  return;
745  }
746 
747  size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
748  size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
749 
750  for (int h = 0; h < num_heads; ++h) {
751  const float *head_in = attn_out + (size_t)h * head_in_stride;
752  const float *wo_h = wo + (size_t)h * head_weight_stride;
753 
754  if (h == 0) {
755  gemm_blocked_serial(head_in, wo_h, bo, out,
756  tokens, aligned_embed_dim, aligned_head_dim);
757  } else {
758  gemm_blocked_serial(head_in, wo_h, NULL, scratch,
759  tokens, aligned_embed_dim, aligned_head_dim);
760  ck_add_inplace(out, scratch, tokens, aligned_embed_dim);
761  }
762  }
763 }
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:661
static void ck_add_inplace(float *dst, const float *src, int tokens, int aligned_embed_dim)

References ck_add_inplace(), and gemm_blocked_serial().

Referenced by ck_attention_project_head_major_quant(), and ck_layer_forward_rmsnorm_swiglu().

◆ ck_attention_project_head_major_backward()

void ck_attention_project_head_major_backward ( const float *  d_out,
const float *  attn_out,
const float *  wo,
float *  d_attn_out,
float *  d_wo,
float *  d_bo,
int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)

Definition at line 800 of file ckernel_orchestration.c.

810 {
811  if (!d_out || !attn_out || !wo || !d_attn_out || !d_wo || !d_bo) {
812  return;
813  }
814 
815  // Bias gradient: sum over tokens once (bias is applied once in forward).
816  for (int d = 0; d < aligned_embed_dim; ++d) {
817  d_bo[d] = 0.0f;
818  }
819  for (int t = 0; t < tokens; ++t) {
820  const float *row = d_out + (size_t)t * (size_t)aligned_embed_dim;
821  for (int d = 0; d < aligned_embed_dim; ++d) {
822  d_bo[d] += row[d];
823  }
824  }
825 
826  size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
827  size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
828 
829  float *tmp_b = (float *)calloc((size_t)aligned_embed_dim, sizeof(float));
830  if (!tmp_b) {
831  return;
832  }
833 
834  for (int h = 0; h < num_heads; ++h) {
835  const float *head_in = attn_out + (size_t)h * head_in_stride;
836  const float *wo_h = wo + (size_t)h * head_weight_stride;
837  float *d_head_in = d_attn_out + (size_t)h * head_in_stride;
838  float *d_wo_h = d_wo + (size_t)h * head_weight_stride;
839 
840  memset(tmp_b, 0, (size_t)aligned_embed_dim * sizeof(float));
841  fc2_backward_kernel(d_out,
842  head_in,
843  wo_h,
844  d_head_in,
845  d_wo_h,
846  tmp_b,
847  tokens,
848  aligned_head_dim,
849  aligned_embed_dim,
850  1);
851  }
852 
853  free(tmp_b);
854 }
void fc2_backward_kernel(const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:118

References fc2_backward_kernel().

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ ck_attention_project_head_major_decode_token()

void ck_attention_project_head_major_decode_token ( const float *  attn_token,
const float *  wo,
const float *  bo,
float *  out_token,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim 
)

Definition at line 115 of file attention_decode_fused.c.

123 {
124  const size_t head_in_stride = (size_t)aligned_head_dim;
125  const size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
126 
127 #pragma omp parallel for schedule(static)
128  for (int j = 0; j < embed_dim; ++j) {
129  float sum = bo ? bo[j] : 0.0f;
130  for (int h = 0; h < num_heads; ++h) {
131  const float *head_in = attn_token + (size_t)h * head_in_stride;
132  const float *wo_row = wo + (size_t)h * head_weight_stride + (size_t)j * (size_t)aligned_head_dim;
133  sum += ck_dot_f32(head_in, wo_row, aligned_head_dim);
134  }
135  out_token[j] = sum;
136  }
137 
138  for (int j = embed_dim; j < aligned_embed_dim; ++j) {
139  out_token[j] = 0.0f;
140  }
141 }
static float ck_dot_f32(const float *a, const float *b, int len)

References ck_dot_f32().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_quant().

◆ ck_gemm_nt_quant()

void ck_gemm_nt_quant ( const float *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K,
CKDataType  dtype 
)

Definition at line 335 of file ckernel_orchestration.c.

341 {
342  switch (dtype) {
343  case CK_DT_FP32:
344  gemm_blocked_serial(A, (const float *)B, bias, C, M, N, K);
345  break;
346  case CK_DT_Q4_K:
347  gemm_nt_q4_k(A, B, bias, C, M, N, K);
348  break;
349  case CK_DT_Q6_K:
350  gemm_nt_q6_k(A, B, bias, C, M, N, K);
351  break;
352  case CK_DT_Q4_0:
353  gemm_nt_q4_0(A, B, bias, C, M, N, K);
354  break;
355  case CK_DT_Q4_1:
356  gemm_nt_q4_1(A, B, bias, C, M, N, K);
357  break;
358  case CK_DT_Q5_0:
359  gemm_nt_q5_0(A, B, bias, C, M, N, K);
360  break;
361  case CK_DT_Q5_1:
362  gemm_nt_q5_1(A, B, bias, C, M, N, K);
363  break;
364  case CK_DT_Q8_0:
365  gemm_nt_q8_0(A, B, bias, C, M, N, K);
366  break;
367  default:
368  break;
369  }
370 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
@ CK_DT_Q4_0
Definition: ckernel_dtype.h:38
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
@ CK_DT_Q4_1
Definition: ckernel_dtype.h:39
@ CK_DT_Q5_1
Definition: ckernel_dtype.h:45
void gemm_nt_q4_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q4_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
void gemm_nt_q5_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q6_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q8_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
#define C(color)
Definition: show_config.c:39

References C, CK_DT_FP32, CK_DT_Q4_0, CK_DT_Q4_1, CK_DT_Q4_K, CK_DT_Q5_0, CK_DT_Q5_1, CK_DT_Q6_K, CK_DT_Q8_0, gemm_blocked_serial(), gemm_nt_q4_0(), gemm_nt_q4_1(), gemm_nt_q4_k(), gemm_nt_q5_0(), gemm_nt_q5_1(), gemm_nt_q6_k(), and gemm_nt_q8_0().

Referenced by ck_attention_project_head_major_quant(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_mlp_swiglu_forward_quant(), ck_qkv_project_head_major_quant(), ck_qkv_project_head_major_token_quant(), and mega_fused_attention_prefill().

◆ ck_layer_backward_rmsnorm_swiglu()

void ck_layer_backward_rmsnorm_swiglu ( const CKLayerBackwardParams p)

Definition at line 2677 of file ckernel_orchestration.c.

2678 {
2679  if (!p) {
2680  return;
2681  }
2682 
2683  int T = p->tokens;
2684  int aligned_embed = p->aligned_embed_dim;
2685  int aligned_head = p->aligned_head_dim;
2686  int aligned_intermediate = p->aligned_intermediate_dim;
2687  int up_dim = 2 * aligned_intermediate;
2688  int num_threads = 1;
2689 
2690  // 1) Residual add (output = residual1 + mlp_out)
2691  ck_residual_add_backward(p->d_output, p->d_residual1, p->d_mlp_out, T, aligned_embed);
2692 
2693  // 2) MLP down proj backward
2694  fc2_backward_kernel(p->d_mlp_out,
2695  p->swiglu_out,
2696  p->w2,
2697  p->d_swiglu_out,
2698  p->d_w2,
2699  p->d_b2,
2700  T,
2701  aligned_intermediate,
2702  aligned_embed,
2703  num_threads);
2704 
2705  // 3) SwiGLU backward
2706  swiglu_backward(p->fc1_out, p->d_swiglu_out, p->d_fc1_out, T, aligned_intermediate);
2707 
2708  // 4) MLP up proj backward
2709  fc1_backward_kernel(p->d_fc1_out,
2710  p->ln2_out,
2711  p->w1,
2712  p->d_ln2_out,
2713  p->d_w1,
2714  p->d_b1,
2715  T,
2716  aligned_embed,
2717  up_dim,
2718  num_threads);
2719 
2720  // 5) RMSNorm (ln2) backward; reuse d_output as scratch for d_residual1_from_ln2
2721  rmsnorm_backward(p->d_ln2_out,
2722  p->residual1,
2723  p->ln2_gamma,
2724  p->ln2_rstd,
2725  p->d_output,
2726  p->d_ln2_gamma,
2727  T,
2728  p->embed_dim,
2729  aligned_embed);
2730  ck_add_inplace(p->d_residual1, p->d_output, T, aligned_embed);
2731 
2732  // 6) Residual add (residual1 = input + proj_tmp)
2733  ck_residual_add_backward(p->d_residual1, p->d_input, p->d_proj_tmp, T, aligned_embed);
2734 
2735  // 7) Attention projection backward
2736  ck_attention_project_head_major_backward(p->d_proj_tmp,
2737  p->attn_out,
2738  p->wo,
2739  p->d_attn_out,
2740  p->d_wo,
2741  p->d_bo,
2742  T,
2743  aligned_embed,
2744  p->num_heads,
2745  aligned_head);
2746 
2747  // 8) Attention backward
2748  attention_backward_causal_head_major_gqa(p->d_attn_out,
2749  p->q,
2750  p->k,
2751  p->v,
2752  p->scores,
2753  p->d_q,
2754  p->d_k,
2755  p->d_v,
2756  p->d_scores,
2757  p->num_heads,
2758  p->num_kv_heads,
2759  T,
2760  p->head_dim,
2761  aligned_head,
2762  p->aligned_context_window);
2763 
2764  // 9) RoPE backward (if enabled)
2765  if (p->rope_cos && p->rope_sin) {
2766  rope_backward_qk(p->d_q,
2767  p->d_k,
2768  p->d_q,
2769  p->d_k,
2770  p->rope_cos,
2771  p->rope_sin,
2772  p->num_heads,
2773  p->num_kv_heads,
2774  T,
2775  p->head_dim,
2776  aligned_head,
2777  p->rope_pos_offset);
2778  }
2779 
2780  // 10) QKV projection backward (scratch uses d_proj_tmp)
2781  ck_qkv_project_head_major_backward(p->d_q,
2782  p->d_k,
2783  p->d_v,
2784  p->ln1_out,
2785  p->wq,
2786  p->bq,
2787  p->wk,
2788  p->bk,
2789  p->wv,
2790  p->bv,
2791  p->d_ln1_out,
2792  p->d_wq,
2793  p->d_bq,
2794  p->d_wk,
2795  p->d_bk,
2796  p->d_wv,
2797  p->d_bv,
2798  p->d_proj_tmp,
2799  T,
2800  aligned_embed,
2801  p->num_heads,
2802  p->num_kv_heads,
2803  aligned_head,
2804  num_threads);
2805 
2806  // 11) RMSNorm (ln1) backward; reuse d_ln1_out as scratch for d_input_from_ln1
2807  rmsnorm_backward(p->d_ln1_out,
2808  p->input,
2809  p->ln1_gamma,
2810  p->ln1_rstd,
2811  p->d_ln1_out,
2812  p->d_ln1_gamma,
2813  T,
2814  p->embed_dim,
2815  aligned_embed);
2816  ck_add_inplace(p->d_input, p->d_ln1_out, T, aligned_embed);
2817 }
void swiglu_backward(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void rope_backward_qk(const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:497
void fc1_backward_kernel(const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:167
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void ck_residual_add_backward(const float *d_out, float *d_a, float *d_b, int tokens, int aligned_embed_dim)
void ck_qkv_project_head_major_backward(const float *d_q, const float *d_k, const float *d_v, const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *d_input, float *d_wq, float *d_bq, float *d_wk, float *d_bk, float *d_wv, float *d_bv, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)
void ck_attention_project_head_major_backward(const float *d_out, const float *attn_out, const float *wo, float *d_attn_out, float *d_wo, float *d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

References CKLayerBackwardParams::aligned_context_window, CKLayerBackwardParams::aligned_embed_dim, CKLayerBackwardParams::aligned_head_dim, CKLayerBackwardParams::aligned_intermediate_dim, attention_backward_causal_head_major_gqa(), CKLayerBackwardParams::attn_out, CKLayerBackwardParams::bk, CKLayerBackwardParams::bq, CKLayerBackwardParams::bv, ck_add_inplace(), ck_attention_project_head_major_backward(), ck_qkv_project_head_major_backward(), ck_residual_add_backward(), CKLayerBackwardParams::d_attn_out, CKLayerBackwardParams::d_b1, CKLayerBackwardParams::d_b2, CKLayerBackwardParams::d_bk, CKLayerBackwardParams::d_bo, CKLayerBackwardParams::d_bq, CKLayerBackwardParams::d_bv, CKLayerBackwardParams::d_fc1_out, CKLayerBackwardParams::d_input, CKLayerBackwardParams::d_k, CKLayerBackwardParams::d_ln1_gamma, CKLayerBackwardParams::d_ln1_out, CKLayerBackwardParams::d_ln2_gamma, CKLayerBackwardParams::d_ln2_out, CKLayerBackwardParams::d_mlp_out, CKLayerBackwardParams::d_output, CKLayerBackwardParams::d_proj_tmp, CKLayerBackwardParams::d_q, CKLayerBackwardParams::d_residual1, CKLayerBackwardParams::d_scores, CKLayerBackwardParams::d_swiglu_out, CKLayerBackwardParams::d_v, CKLayerBackwardParams::d_w1, CKLayerBackwardParams::d_w2, CKLayerBackwardParams::d_wk, CKLayerBackwardParams::d_wo, CKLayerBackwardParams::d_wq, CKLayerBackwardParams::d_wv, CKLayerBackwardParams::embed_dim, fc1_backward_kernel(), CKLayerBackwardParams::fc1_out, fc2_backward_kernel(), CKLayerBackwardParams::head_dim, CKLayerBackwardParams::input, CKLayerBackwardParams::k, CKLayerBackwardParams::ln1_gamma, CKLayerBackwardParams::ln1_out, CKLayerBackwardParams::ln1_rstd, CKLayerBackwardParams::ln2_gamma, CKLayerBackwardParams::ln2_out, CKLayerBackwardParams::ln2_rstd, CKLayerBackwardParams::num_heads, CKLayerBackwardParams::num_kv_heads, CKLayerBackwardParams::q, CKLayerBackwardParams::residual1, rmsnorm_backward(), rope_backward_qk(), CKLayerBackwardParams::rope_cos, CKLayerBackwardParams::rope_pos_offset, 
CKLayerBackwardParams::rope_sin, CKLayerBackwardParams::scores, swiglu_backward(), CKLayerBackwardParams::swiglu_out, CKLayerBackwardParams::tokens, CKLayerBackwardParams::v, CKLayerBackwardParams::w1, CKLayerBackwardParams::w2, CKLayerBackwardParams::wk, CKLayerBackwardParams::wo, CKLayerBackwardParams::wq, and CKLayerBackwardParams::wv.

◆ ck_layer_forward_rmsnorm_swiglu()

void ck_layer_forward_rmsnorm_swiglu ( const CKLayerForwardParams p)

Definition at line 996 of file ckernel_orchestration.c.

997 {
998  if (!p) {
999  return;
1000  }
1001 
1002  rmsnorm_forward(p->input,
1003  p->ln1_gamma,
1004  p->ln1_out,
1005  p->ln1_rstd,
1006  p->tokens,
1007  p->embed_dim,
1008  p->aligned_embed_dim,
1009  p->eps);
1010 
1011  ck_qkv_project_head_major(p->ln1_out,
1012  p->wq, p->bq,
1013  p->wk, p->bk,
1014  p->wv, p->bv,
1015  p->q, p->k, p->v,
1016  p->tokens,
1017  p->tokens,
1018  p->aligned_embed_dim,
1019  p->num_heads,
1020  p->num_kv_heads,
1021  p->aligned_head_dim);
1022 
1023  if (p->rope_cos && p->rope_sin) {
1024  rope_forward_qk(p->q,
1025  p->k,
1026  p->rope_cos,
1027  p->rope_sin,
1028  p->num_heads,
1029  p->num_kv_heads,
1030  p->tokens,
1031  p->head_dim,
1032  p->aligned_head_dim,
1033  p->rope_pos_offset);
1034  }
1035 
1036  if (p->scores) {
1037  attention_forward_causal_head_major_gqa(p->q,
1038  p->k,
1039  p->v,
1040  p->scores,
1041  p->attn_out,
1042  p->num_heads,
1043  p->num_kv_heads,
1044  p->tokens,
1045  p->head_dim,
1046  p->aligned_head_dim,
1047  p->aligned_context_window);
1048  } else {
1049  attention_forward_causal_head_major_gqa_flash(p->q,
1050  p->k,
1051  p->v,
1052  p->attn_out,
1053  p->num_heads,
1054  p->num_kv_heads,
1055  p->tokens,
1056  p->head_dim,
1057  p->aligned_head_dim);
1058  }
1059 
1060  ck_attention_project_head_major(p->attn_out,
1061  p->wo,
1062  p->bo,
1063  p->proj_tmp,
1064  p->proj_scratch,
1065  p->tokens,
1066  p->aligned_embed_dim,
1067  p->num_heads,
1068  p->aligned_head_dim);
1069 
1070  ck_residual_add_token_major(p->input,
1071  p->proj_tmp,
1072  p->residual1,
1073  p->tokens,
1074  p->aligned_embed_dim);
1075 
1076  rmsnorm_forward(p->residual1,
1077  p->ln2_gamma,
1078  p->ln2_out,
1079  p->ln2_rstd,
1080  p->tokens,
1081  p->embed_dim,
1082  p->aligned_embed_dim,
1083  p->eps);
1084 
1085  ck_mlp_swiglu_forward(p->ln2_out,
1086  p->w1,
1087  p->b1,
1088  p->w2,
1089  p->b2,
1090  p->fc1_out,
1091  p->swiglu_out,
1092  p->mlp_out,
1093  p->tokens,
1094  p->aligned_embed_dim,
1095  p->aligned_intermediate_dim);
1096 
1097  ck_residual_add_token_major(p->residual1,
1098  p->mlp_out,
1099  p->output,
1100  p->tokens,
1101  p->aligned_embed_dim);
1102 }
void attention_forward_causal_head_major_gqa(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:448
void ck_attention_project_head_major(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_mlp_swiglu_forward(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_qkv_project_head_major(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)

References CKLayerForwardParams::aligned_context_window, CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParams::attn_out, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_project_head_major(), ck_mlp_swiglu_forward(), ck_qkv_project_head_major(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_scratch, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::q, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::scores, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::tokens, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.

◆ ck_layer_forward_rmsnorm_swiglu_decode()

void ck_layer_forward_rmsnorm_swiglu_decode ( const CKLayerForwardParams p,
int  token_index,
int  cache_capacity 
)

Definition at line 1289 of file ckernel_orchestration.c.

1292 {
1293  if (!p) {
1294  return;
1295  }
1296  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
1297  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
1298  !p->k || !p->v ||
1299  !p->proj_tmp || !p->residual1 || !p->fc1_out || !p->swiglu_out || !p->mlp_out || !p->output) {
1300  return;
1301  }
1302  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
1303  return;
1304  }
1305  if (p->num_heads <= 0 || p->num_kv_heads <= 0 || p->aligned_head_dim <= 0) {
1306  return;
1307  }
1308 
1309  const int D = p->embed_dim;
1310  const int aligned_D = p->aligned_embed_dim;
1311  const int H = p->num_heads;
1312  const int H_kv = p->num_kv_heads;
1313  const int hd = p->head_dim;
1314  const int ad = p->aligned_head_dim;
1315  const int aligned_intermediate = p->aligned_intermediate_dim;
1316 
1317  /* Decode buffers are single-token; token_index only applies to KV cache. */
1318  const size_t token_slot = 0;
1319  const float *input_row = p->input + token_slot * (size_t)aligned_D;
1320  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
1321  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
1322  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
1323  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
1324  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
1325  float *out_row = p->output + token_slot * (size_t)aligned_D;
1326 
1327  float ln1_rstd_tmp = 0.0f;
1328  float ln2_rstd_tmp = 0.0f;
1329  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
1330  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
1331 
1332  // Scratch for a single token in head-major layout: [head, aligned_head_dim].
1333  size_t q_elems = (size_t)H * (size_t)ad;
1334  size_t kv_elems = (size_t)H_kv * (size_t)ad;
1335  float q_token[q_elems];
1336  float k_token[kv_elems];
1337  float v_token[kv_elems];
1338  float attn_token[q_elems];
1339 
1340  // LN1 / RMSNorm.
1341  rmsnorm_forward(input_row,
1342  p->ln1_gamma,
1343  ln1_row,
1344  ln1_rstd,
1345  /*tokens=*/1,
1346  D,
1347  aligned_D,
1348  p->eps);
1349 
1350  // Project Q/K/V for the new token.
1351  ck_qkv_project_head_major_token(ln1_row,
1352  p->wq, p->bq,
1353  p->wk, p->bk,
1354  p->wv, p->bv,
1355  q_token, k_token, v_token,
1356  aligned_D,
1357  H,
1358  H_kv,
1359  ad);
1360 
1361  // RoPE for the new token at absolute position `p->rope_pos_offset`.
1362  if (p->rope_cos && p->rope_sin) {
1363  rope_forward_qk(q_token,
1364  k_token,
1365  p->rope_cos,
1366  p->rope_sin,
1367  H,
1368  H_kv,
1369  /*num_tokens=*/1,
1370  hd,
1371  ad,
1372  p->rope_pos_offset);
1373  }
1374 
1375  // Update KV cache (stores k/v for this token and clears padded lanes).
1376  kv_cache_write_head_major(k_token,
1377  v_token,
1378  p->k,
1379  p->v,
1380  H_kv,
1381  token_index,
1382  cache_capacity,
1383  hd,
1384  ad);
1385 
1386  // Decode attention for this token using the KV cache.
1387  ck_attention_flash_decode_wrapper(q_token,
1388  p->k,
1389  p->v,
1390  attn_token,
1391  H,
1392  H_kv,
1393  /*kv_tokens=*/token_index + 1,
1394  cache_capacity,
1395  hd,
1396  ad);
1397 
1398  // Output projection (Wo) into token-major buffer (decode-specialized).
1399  ck_attention_project_head_major_decode_token(attn_token,
1400  p->wo,
1401  p->bo,
1402  proj_row,
1403  D,
1404  aligned_D,
1405  H,
1406  ad);
1407 
1408  // Residual + LN2 / RMSNorm.
1409  ck_residual_add_token_major(input_row,
1410  proj_row,
1411  residual_row,
1412  /*tokens=*/1,
1413  aligned_D);
1414 
1415  rmsnorm_forward(residual_row,
1416  p->ln2_gamma,
1417  ln2_row,
1418  ln2_rstd,
1419  /*tokens=*/1,
1420  D,
1421  aligned_D,
1422  p->eps);
1423 
1424  // MLP block for this token.
1425  int up_dim = 2 * aligned_intermediate;
1426  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
1427  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
1428 
1429  ck_mlp_swiglu_forward(ln2_row,
1430  p->w1,
1431  p->b1,
1432  p->w2,
1433  p->b2,
1434  fc1_row,
1435  swiglu_row,
1436  mlp_row,
1437  /*tokens=*/1,
1438  aligned_D,
1439  aligned_intermediate);
1440 
1441  // Final residual.
1442  ck_residual_add_token_major(residual_row,
1443  mlp_row,
1444  out_row,
1445  /*tokens=*/1,
1446  aligned_D);
1447 }
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
void ck_attention_flash_decode_wrapper(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
Wrapper to call TRUE flash attention from orchestration layer.
void ck_qkv_project_head_major_token(const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_attention_project_head_major_decode_token(const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)

References CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_flash_decode_wrapper(), ck_attention_project_head_major_decode_token(), ck_mlp_swiglu_forward(), ck_qkv_project_head_major_token(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, kv_cache_write_head_major(), CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.

◆ ck_layer_forward_rmsnorm_swiglu_decode_fused()

void ck_layer_forward_rmsnorm_swiglu_decode_fused ( const CKLayerForwardParams *  p,
int  token_index,
int  cache_capacity 
)

Definition at line 1449 of file ckernel_orchestration.c.

1452 {
1453  if (!p) {
1454  return;
1455  }
1456  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
1457  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
1458  !p->k || !p->v || !p->swiglu_out ||
1459  !p->proj_tmp || !p->residual1 || !p->mlp_out || !p->output) {
1460  return;
1461  }
1462  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
1463  return;
1464  }
1465  if (p->num_heads <= 0 || p->num_kv_heads <= 0 || p->aligned_head_dim <= 0) {
1466  return;
1467  }
1468 
1469  const int D = p->embed_dim;
1470  const int aligned_D = p->aligned_embed_dim;
1471  const int H = p->num_heads;
1472  const int H_kv = p->num_kv_heads;
1473  const int hd = p->head_dim;
1474  const int ad = p->aligned_head_dim;
1475  const int aligned_intermediate = p->aligned_intermediate_dim;
1476 
1477  /* Decode buffers are single-token; token_index only applies to KV cache. */
1478  const size_t token_slot = 0;
1479  const float *input_row = p->input + token_slot * (size_t)aligned_D;
1480  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
1481  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
1482  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
1483  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
1484  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
1485  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
1486  float *out_row = p->output + token_slot * (size_t)aligned_D;
1487 
1488  float ln1_rstd_tmp = 0.0f;
1489  float ln2_rstd_tmp = 0.0f;
1490  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
1491  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
1492 
1493  // Scratch for a single token in head-major layout: [head, aligned_head_dim].
1494  size_t q_elems = (size_t)H * (size_t)ad;
1495  size_t kv_elems = (size_t)H_kv * (size_t)ad;
1496  float q_token[q_elems];
1497  float k_token[kv_elems];
1498  float v_token[kv_elems];
1499  float attn_token[q_elems];
1500 
1501  // LN1 / RMSNorm.
1502  rmsnorm_forward(input_row,
1503  p->ln1_gamma,
1504  ln1_row,
1505  ln1_rstd,
1506  /*tokens=*/1,
1507  D,
1508  aligned_D,
1509  p->eps);
1510 
 1511  // Project Q/K/V for the new token.
 1512  ck_qkv_project_head_major_token(ln1_row,
 1513  p->wq, p->bq,
1514  p->wk, p->bk,
1515  p->wv, p->bv,
1516  q_token, k_token, v_token,
1517  aligned_D,
1518  H,
1519  H_kv,
1520  ad);
1521 
1522  // RoPE for the new token at absolute position `p->rope_pos_offset`.
1523  if (p->rope_cos && p->rope_sin) {
1524  rope_forward_qk(q_token,
1525  k_token,
1526  p->rope_cos,
1527  p->rope_sin,
1528  H,
1529  H_kv,
1530  /*num_tokens=*/1,
1531  hd,
1532  ad,
1533  p->rope_pos_offset);
1534  }
1535 
1536  // Update KV cache (stores k/v for this token and clears padded lanes).
1537  kv_cache_write_head_major(k_token,
1538  v_token,
1539  p->k,
1540  p->v,
1541  H_kv,
1542  token_index,
1543  cache_capacity,
1544  hd,
1545  ad);
1546 
 1547  // Decode attention for this token using the KV cache.
 1548  ck_attention_flash_decode_wrapper(q_token,
 1549  p->k,
1550  p->v,
1551  attn_token,
1552  H,
1553  H_kv,
1554  /*kv_tokens=*/token_index + 1,
1555  cache_capacity,
1556  hd,
1557  ad);
1558 
 1559  // Output projection (Wo) into token-major buffer (decode-specialized).
 1560  ck_attention_project_head_major_decode_token(attn_token,
 1561  p->wo,
1562  p->bo,
1563  proj_row,
1564  D,
1565  aligned_D,
1566  H,
1567  ad);
1568 
1569  // Residual + LN2 / RMSNorm.
1570  ck_residual_add_token_major(input_row,
1571  proj_row,
1572  residual_row,
1573  /*tokens=*/1,
1574  aligned_D);
1575 
1576  rmsnorm_forward(residual_row,
1577  p->ln2_gamma,
1578  ln2_row,
1579  ln2_rstd,
1580  /*tokens=*/1,
1581  D,
1582  aligned_D,
1583  p->eps);
1584 
 1585  // MLP block for this token (fully fused - all 3 projections in one pass).
 1586  // Eliminates DRAM round-trip for swiglu intermediate values.
 1587  ck_mlp_swiglu_forward_fully_fused_token(ln2_row,
 1588  p->w1,
1589  p->b1,
1590  p->w2,
1591  p->b2,
1592  mlp_row,
1593  aligned_D,
1594  aligned_intermediate);
1595 
1596  // Final residual.
1597  ck_residual_add_token_major(residual_row,
1598  mlp_row,
1599  out_row,
1600  /*tokens=*/1,
1601  aligned_D);
1602 }
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)

References CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_flash_decode_wrapper(), ck_attention_project_head_major_decode_token(), ck_mlp_swiglu_forward_fully_fused_token(), ck_qkv_project_head_major_token(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, kv_cache_write_head_major(), CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.

◆ ck_layer_forward_rmsnorm_swiglu_decode_fused_attn()

void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn ( const CKLayerForwardParams *  p,
int  token_index,
int  cache_capacity 
)

Definition at line 343 of file attention_decode_fused.c.

 346 {
 347  ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(p,
 348  token_index,
349  cache_capacity,
350  /*fuse_mlp=*/0);
351 }
static void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(const CKLayerForwardParams *p, int token_index, int cache_capacity, int fuse_mlp)

References ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp()

void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp ( const CKLayerForwardParams *  p,
int  token_index,
int  cache_capacity 
)

Definition at line 353 of file attention_decode_fused.c.

 356 {
 357  ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(p,
 358  token_index,
359  cache_capacity,
360  /*fuse_mlp=*/1);
361 }

References ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_layer_forward_rmsnorm_swiglu_decode_q4_k()

void ck_layer_forward_rmsnorm_swiglu_decode_q4_k ( const CKLayerForwardParamsQ4K *  p,
int  token_index,
int  cache_capacity 
)

Definition at line 2117 of file ckernel_orchestration.c.

2120 {
2121  if (!p) {
2122  return;
2123  }
2124  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
2125  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
2126  !p->k || !p->v ||
2127  !p->proj_tmp || !p->residual1 || !p->fc1_out || !p->swiglu_out || !p->mlp_out || !p->output) {
2128  return;
2129  }
2130  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
2131  return;
2132  }
2133 
2134  const int D = p->embed_dim;
2135  const int aligned_D = p->aligned_embed_dim;
2136  const int H = p->num_heads;
2137  const int H_kv = p->num_kv_heads;
2138  const int hd = p->head_dim;
2139  const int ad = p->aligned_head_dim;
2140  const int aligned_intermediate = p->aligned_intermediate_dim;
2141  const int K_concat = H * ad;
2142 
2143  /* Decode buffers are single-token; token_index only applies to KV cache. */
2144  const size_t token_slot = 0;
2145  const float *input_row = p->input + token_slot * (size_t)aligned_D;
2146  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
2147  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
2148  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
2149  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
2150  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
2151  float *out_row = p->output + token_slot * (size_t)aligned_D;
2152 
2153  float ln1_rstd_tmp = 0.0f;
2154  float ln2_rstd_tmp = 0.0f;
2155  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
2156  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
2157 
2158  /* Scratch for a single token in head-major layout: [head, aligned_head_dim]. */
2159  size_t q_elems = (size_t)H * (size_t)ad;
2160  size_t kv_elems = (size_t)H_kv * (size_t)ad;
2161  float q_token[q_elems];
2162  float k_token[kv_elems];
2163  float v_token[kv_elems];
2164  float attn_token[q_elems];
2165 
2166  /* LN1 / RMSNorm. */
2167  ck_debug_check_buffer("input_row", input_row, aligned_D);
2168  rmsnorm_forward(input_row,
2169  p->ln1_gamma,
2170  ln1_row,
2171  ln1_rstd,
2172  /*tokens=*/1,
2173  D,
2174  aligned_D,
2175  p->eps);
2176  ck_debug_check_buffer("ln1_out (after rmsnorm)", ln1_row, aligned_D);
 2177 
 2178  if (ck_q8k_activations_enabled()) {
 2179  if ((aligned_D % QK_K) == 0 && (aligned_intermediate % QK_K) == 0) {
2180  const int q8_blocks_embed = aligned_D / QK_K;
2181  const int q8_blocks_inter = aligned_intermediate / QK_K;
2182  const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
2183  block_q8_K q8_buf[q8_blocks_max];
2184 
2185  /* Project Q/K/V with Q8_K activations. */
2186  quantize_row_q8_k(ln1_row, q8_buf, aligned_D);
2187  ck_debug_check_q8k("q8_buf (after quantize)", q8_buf, q8_blocks_embed);
 2188  ck_debug_check_q4k_weights("wq weights", p->wq, (aligned_D / QK_K) * (H * ad));
 2189  ck_qkv_project_head_major_token_q4_k_q8_k(q8_buf,
 2190  p->wq, p->bq,
2191  p->wk, p->bk,
2192  p->wv, p->bv,
2193  q_token, k_token, v_token,
2194  aligned_D,
2195  H,
2196  H_kv,
2197  ad);
2198  ck_debug_check_buffer("q_token (after QKV proj)", q_token, (int)q_elems);
2199  ck_debug_check_buffer("k_token (after QKV proj)", k_token, (int)kv_elems);
2200  ck_debug_check_buffer("v_token (after QKV proj)", v_token, (int)kv_elems);
2201 
2202  /* RoPE for the new token at absolute position `p->rope_pos_offset`. */
2203  if (p->rope_cos && p->rope_sin) {
2204  rope_forward_qk(q_token,
2205  k_token,
2206  p->rope_cos,
2207  p->rope_sin,
2208  H,
2209  H_kv,
2210  /*num_tokens=*/1,
2211  hd,
2212  ad,
2213  p->rope_pos_offset);
2214  }
2215 
2216  /* Update KV cache. */
2217  kv_cache_write_head_major(k_token,
2218  v_token,
2219  p->k,
2220  p->v,
2221  H_kv,
2222  token_index,
2223  cache_capacity,
2224  hd,
2225  ad);
2226 
2227  /* Decode attention for this token using the KV cache. */
2229  p->k,
2230  p->v,
2231  attn_token,
2232  H,
2233  H_kv,
2234  /*kv_tokens=*/token_index + 1,
2235  cache_capacity,
2236  hd,
2237  ad);
2238  ck_debug_check_buffer("attn_token (after attention)", attn_token, (int)q_elems);
2239 
2240  /* Quantized output projection (Wo) with Q8_K activations. */
2241  quantize_row_q8_k(attn_token, q8_buf, aligned_D);
2242  gemm_nt_q4_k_q8_k(q8_buf,
2243  p->wo,
2244  p->bo,
2245  proj_row,
2246  /*M=*/1,
2247  aligned_D,
2248  /*K=*/K_concat);
2249  ck_debug_check_buffer("proj_row (after Wo proj)", proj_row, aligned_D);
2250 
2251  for (int j = D; j < aligned_D; ++j) {
2252  proj_row[j] = 0.0f;
2253  }
2254 
2255  /* Residual + LN2 / RMSNorm. */
2256  ck_residual_add_token_major(input_row,
2257  proj_row,
2258  residual_row,
2259  /*tokens=*/1,
2260  aligned_D);
2261 
2262  rmsnorm_forward(residual_row,
2263  p->ln2_gamma,
2264  ln2_row,
2265  ln2_rstd,
2266  /*tokens=*/1,
2267  D,
2268  aligned_D,
2269  p->eps);
2270 
2271  /* MLP block for this token (Q8_K activations). */
2272  int up_dim = 2 * aligned_intermediate;
2273  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
2274  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
 2275 
 2276  ck_mlp_swiglu_forward_q4_k_q8_k(ln2_row,
 2277  p->w1,
2278  p->b1,
2279  p->w2,
2280  p->b2,
2281  fc1_row,
2282  swiglu_row,
2283  mlp_row,
2284  aligned_D,
2285  aligned_intermediate);
2286  ck_debug_check_buffer("mlp_row (after MLP)", mlp_row, aligned_D);
2287 
2288  /* Final residual. */
2289  ck_residual_add_token_major(residual_row,
2290  mlp_row,
2291  out_row,
2292  /*tokens=*/1,
2293  aligned_D);
2294  ck_debug_check_buffer("out_row (final output)", out_row, aligned_D);
2295  return;
2296  }
2297  }
2298 
 2299  /* Project Q/K/V for the new token (Q4_K weights). */
 2300  ck_qkv_project_head_major_token_q4_k(ln1_row,
 2301  p->wq, p->bq,
2302  p->wk, p->bk,
2303  p->wv, p->bv,
2304  q_token, k_token, v_token,
2305  aligned_D,
2306  H,
2307  H_kv,
2308  ad);
2309 
2310  /* RoPE for the new token at absolute position `p->rope_pos_offset`. */
2311  if (p->rope_cos && p->rope_sin) {
2312  rope_forward_qk(q_token,
2313  k_token,
2314  p->rope_cos,
2315  p->rope_sin,
2316  H,
2317  H_kv,
2318  /*num_tokens=*/1,
2319  hd,
2320  ad,
2321  p->rope_pos_offset);
2322  }
2323 
2324  /* Update KV cache. */
2325  kv_cache_write_head_major(k_token,
2326  v_token,
2327  p->k,
2328  p->v,
2329  H_kv,
2330  token_index,
2331  cache_capacity,
2332  hd,
2333  ad);
2334 
2335  /* Decode attention for this token using the KV cache. */
2337  p->k,
2338  p->v,
2339  attn_token,
2340  H,
2341  H_kv,
2342  /*kv_tokens=*/token_index + 1,
2343  cache_capacity,
2344  hd,
2345  ad);
2346 
2347  /* Quantized output projection: Wo is stored as a flattened Q4_K matrix. */
2348  gemm_nt_q4_k(attn_token,
2349  p->wo,
2350  p->bo,
2351  proj_row,
2352  /*M=*/1,
2353  aligned_D,
2354  /*K=*/K_concat);
2355 
2356  for (int j = D; j < aligned_D; ++j) {
2357  proj_row[j] = 0.0f;
2358  }
2359 
2360  /* Residual + LN2 / RMSNorm. */
2361  ck_residual_add_token_major(input_row,
2362  proj_row,
2363  residual_row,
2364  /*tokens=*/1,
2365  aligned_D);
2366 
2367  rmsnorm_forward(residual_row,
2368  p->ln2_gamma,
2369  ln2_row,
2370  ln2_rstd,
2371  /*tokens=*/1,
2372  D,
2373  aligned_D,
2374  p->eps);
2375 
2376  /* MLP block for this token. */
2377  int up_dim = 2 * aligned_intermediate;
2378  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
2379  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
 2380 
 2381  ck_mlp_swiglu_forward_q4_k(ln2_row,
 2382  p->w1,
2383  p->b1,
2384  p->w2,
2385  p->b2,
2386  fc1_row,
2387  swiglu_row,
2388  mlp_row,
2389  /*tokens=*/1,
2390  aligned_D,
2391  aligned_intermediate);
2392 
2393  /* Final residual. */
2394  ck_residual_add_token_major(residual_row,
2395  mlp_row,
2396  out_row,
2397  /*tokens=*/1,
2398  aligned_D);
2399 }
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void quantize_row_q8_k(const float *x, void *y, int k)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
static void ck_qkv_project_head_major_token_q4_k_q8_k(const block_q8_K *input_q8, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static int ck_q8k_activations_enabled(void)
static void ck_mlp_swiglu_forward_q4_k_q8_k(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_debug_check_q8k(const char *stage, const void *q8_buf, int num_blocks)
static void ck_mlp_swiglu_forward_q4_k(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_debug_check_q4k_weights(const char *stage, const void *q4_buf, int num_blocks)
static void ck_debug_check_buffer(const char *stage, const float *buf, int size)
static void ck_qkv_project_head_major_token_q4_k(const float *input_row, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
#define QK_K

References CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, attention_forward_decode_head_major_gqa_regular(), CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_flash_decode_wrapper(), ck_debug_check_buffer(), ck_debug_check_q4k_weights(), ck_debug_check_q8k(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_q8k_activations_enabled(), ck_qkv_project_head_major_token_q4_k(), ck_qkv_project_head_major_token_q4_k_q8_k(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, gemm_nt_q4_k(), gemm_nt_q4_k_q8_k(), CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, kv_cache_write_head_major(), CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_tmp, QK_K, quantize_row_q8_k(), CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wq, and CKLayerForwardParamsQ4K::wv.

◆ ck_layer_forward_rmsnorm_swiglu_decode_quant()

void ck_layer_forward_rmsnorm_swiglu_decode_quant ( const CKLayerForwardParamsQ4K *  p,
int  token_index,
int  cache_capacity 
)

Definition at line 2512 of file ckernel_orchestration.c.

2515 {
2516  if (!p) {
2517  return;
2518  }
2519  if (!p->input || !p->ln1_gamma || !p->ln2_gamma || !p->ln1_out || !p->ln2_out ||
2520  !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
2521  !p->k || !p->v ||
2522  !p->proj_tmp || !p->proj_scratch || !p->residual1 || !p->fc1_out || !p->swiglu_out || !p->mlp_out || !p->output) {
2523  return;
2524  }
2525  if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
2526  return;
2527  }
2528 
2529  const int D = p->embed_dim;
2530  const int aligned_D = p->aligned_embed_dim;
2531  const int H = p->num_heads;
2532  const int H_kv = p->num_kv_heads;
2533  const int hd = p->head_dim;
2534  const int ad = p->aligned_head_dim;
2535  const int aligned_intermediate = p->aligned_intermediate_dim;
2536  const int K_concat = H * ad;
2537 
2538  /* Decode buffers are single-token; token_index only applies to KV cache. */
2539  const size_t token_slot = 0;
2540  const float *input_row = p->input + token_slot * (size_t)aligned_D;
2541  float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
2542  float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
2543  float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
2544  float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
2545  float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
2546  float *out_row = p->output + token_slot * (size_t)aligned_D;
2547 
2548  float ln1_rstd_tmp = 0.0f;
2549  float ln2_rstd_tmp = 0.0f;
2550  float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
2551  float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
2552 
2553  size_t q_elems = (size_t)H * (size_t)ad;
2554  size_t kv_elems = (size_t)H_kv * (size_t)ad;
2555  float q_token[q_elems];
2556  float k_token[kv_elems];
2557  float v_token[kv_elems];
2558  float attn_token[q_elems];
2559 
2560  rmsnorm_forward(input_row,
2561  p->ln1_gamma,
2562  ln1_row,
2563  ln1_rstd,
2564  /*tokens=*/1,
2565  D,
2566  aligned_D,
2567  p->eps);
 2568 
 2569  ck_qkv_project_head_major_token_quant(ln1_row,
 2570  p->wq, p->bq, p->wq_dtype,
2571  p->wk, p->bk, p->wk_dtype,
2572  p->wv, p->bv, p->wv_dtype,
2573  q_token, k_token, v_token,
2574  aligned_D,
2575  H,
2576  H_kv,
2577  ad);
2578 
2579  if (p->rope_cos && p->rope_sin) {
2580  rope_forward_qk(q_token,
2581  k_token,
2582  p->rope_cos,
2583  p->rope_sin,
2584  H,
2585  H_kv,
2586  /*num_tokens=*/1,
2587  hd,
2588  ad,
2589  p->rope_pos_offset);
2590  }
2591 
2592  kv_cache_write_head_major(k_token,
2593  v_token,
2594  p->k,
2595  p->v,
2596  H_kv,
2597  token_index,
2598  cache_capacity,
2599  hd,
2600  ad);
 2601 
 2602  ck_attention_flash_decode_wrapper(q_token,
 2603  p->k,
2604  p->v,
2605  attn_token,
2606  H,
2607  H_kv,
2608  /*kv_tokens=*/token_index + 1,
2609  cache_capacity,
2610  hd,
2611  ad);
2612 
 2613  if (p->wo_dtype == CK_DT_FP32) {
 2614  ck_attention_project_head_major_decode_token(attn_token,
 2615  (const float *)p->wo,
2616  p->bo,
2617  proj_row,
2618  D,
2619  aligned_D,
2620  H,
2621  ad);
2622  } else {
2623  /* Quantized attention output projection - handle all quant types */
2624  ck_gemm_nt_quant(attn_token,
2625  p->wo,
2626  p->bo,
2627  proj_row,
2628  /*M=*/1,
2629  aligned_D,
2630  /*K=*/K_concat,
2631  p->wo_dtype);
2632  for (int j = D; j < aligned_D; ++j) {
2633  proj_row[j] = 0.0f;
2634  }
2635  }
2636 
2637  ck_residual_add_token_major(input_row,
2638  proj_row,
2639  residual_row,
2640  /*tokens=*/1,
2641  aligned_D);
2642 
2643  rmsnorm_forward(residual_row,
2644  p->ln2_gamma,
2645  ln2_row,
2646  ln2_rstd,
2647  /*tokens=*/1,
2648  D,
2649  aligned_D,
2650  p->eps);
2651 
2652  int up_dim = 2 * aligned_intermediate;
2653  float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
2654  float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
 2655 
 2656  ck_mlp_swiglu_forward_quant(ln2_row,
 2657  p->w1,
2658  p->b1,
2659  p->w1_dtype,
2660  p->w2,
2661  p->b2,
2662  p->w2_dtype,
2663  fc1_row,
2664  swiglu_row,
2665  mlp_row,
2666  /*tokens=*/1,
2667  aligned_D,
2668  aligned_intermediate);
2669 
2670  ck_residual_add_token_major(residual_row,
2671  mlp_row,
2672  out_row,
2673  /*tokens=*/1,
2674  aligned_D);
2675 }
static void ck_mlp_swiglu_forward_quant(const float *input, const void *w1, const float *b1, CKDataType w1_dtype, const void *w2, const float *b2, CKDataType w2_dtype, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
static void ck_qkv_project_head_major_token_quant(const float *input_row, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

References CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_flash_decode_wrapper(), ck_attention_project_head_major_decode_token(), CK_DT_FP32, ck_gemm_nt_quant(), ck_mlp_swiglu_forward_quant(), ck_qkv_project_head_major_token_quant(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, kv_cache_write_head_major(), CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_scratch, CKLayerForwardParamsQ4K::proj_tmp, CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w1_dtype, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::w2_dtype, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wk_dtype, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wo_dtype, CKLayerForwardParamsQ4K::wq, CKLayerForwardParamsQ4K::wq_dtype, CKLayerForwardParamsQ4K::wv, and CKLayerForwardParamsQ4K::wv_dtype.

◆ ck_layer_forward_rmsnorm_swiglu_q4_k()

void ck_layer_forward_rmsnorm_swiglu_q4_k ( const CKLayerForwardParamsQ4K *  p)

Definition at line 1910 of file ckernel_orchestration.c.

1911 {
1912  if (!p) {
1913  return;
1914  }
1915 
1916  const int aligned_D = p->aligned_embed_dim;
1917  const int aligned_intermediate = p->aligned_intermediate_dim;
 1918 
 1919  rmsnorm_forward(p->input,
 1920  p->ln1_gamma,
1921  p->ln1_out,
1922  p->ln1_rstd,
1923  p->tokens,
1924  p->embed_dim,
1925  aligned_D,
1926  p->eps);
 1927 
 1928  if (ck_q8k_activations_enabled()) {
 1929  if ((aligned_D % QK_K) == 0 && (aligned_intermediate % QK_K) == 0) {
 1930  ck_qkv_project_head_major_q4_k_q8_k(p->ln1_out,
 1931  p->wq, p->bq,
1932  p->wk, p->bk,
1933  p->wv, p->bv,
1934  p->q, p->k, p->v,
1935  p->tokens,
1936  p->tokens,
1937  aligned_D,
1938  p->num_heads,
1939  p->num_kv_heads,
1940  p->aligned_head_dim);
1941 
1942  if (p->rope_cos && p->rope_sin) {
1943  rope_forward_qk(p->q,
1944  p->k,
1945  p->rope_cos,
1946  p->rope_sin,
1947  p->num_heads,
1948  p->num_kv_heads,
1949  p->tokens,
1950  p->head_dim,
1951  p->aligned_head_dim,
1952  p->rope_pos_offset);
1953  }
1954 
1955  if (p->scores) {
1957  p->k,
1958  p->v,
1959  p->scores,
1960  p->attn_out,
1961  p->num_heads,
1962  p->num_kv_heads,
1963  p->tokens,
1964  p->head_dim,
1965  p->aligned_head_dim,
1967  } else {
1969  p->k,
1970  p->v,
1971  p->attn_out,
1972  p->num_heads,
1973  p->num_kv_heads,
1974  p->tokens,
1975  p->head_dim,
1976  p->aligned_head_dim);
1977  }
 1978 
 1979  ck_attention_project_head_major_q4_k_q8_k(p->attn_out,
 1980  p->wo,
1981  p->bo,
1982  p->proj_tmp,
1983  p->tokens,
1984  aligned_D,
1985  p->num_heads,
1986  p->aligned_head_dim);
1987 
1989  p->proj_tmp,
1990  p->residual1,
1991  p->tokens,
1992  aligned_D);
1993 
1995  p->ln2_gamma,
1996  p->ln2_out,
1997  p->ln2_rstd,
1998  p->tokens,
1999  p->embed_dim,
2000  aligned_D,
2001  p->eps);
2002 
2004  p->w1,
2005  p->b1,
2006  p->w2,
2007  p->b2,
2008  p->fc1_out,
2009  p->swiglu_out,
2010  p->mlp_out,
2011  p->tokens,
2012  aligned_D,
2013  aligned_intermediate);
2014 
2016  p->mlp_out,
2017  p->output,
2018  p->tokens,
2019  aligned_D);
2020  return;
2021  }
2022  }
 2023 
 2024  ck_qkv_project_head_major_q4_k(p->ln1_out,
 2025  p->wq, p->bq,
2026  p->wk, p->bk,
2027  p->wv, p->bv,
2028  p->q, p->k, p->v,
2029  p->tokens,
2030  p->tokens,
2031  aligned_D,
2032  p->num_heads,
2033  p->num_kv_heads,
2034  p->aligned_head_dim);
2035 
2036  if (p->rope_cos && p->rope_sin) {
2037  rope_forward_qk(p->q,
2038  p->k,
2039  p->rope_cos,
2040  p->rope_sin,
2041  p->num_heads,
2042  p->num_kv_heads,
2043  p->tokens,
2044  p->head_dim,
2045  p->aligned_head_dim,
2046  p->rope_pos_offset);
2047  }
2048 
2049  if (p->scores) {
2051  p->k,
2052  p->v,
2053  p->scores,
2054  p->attn_out,
2055  p->num_heads,
2056  p->num_kv_heads,
2057  p->tokens,
2058  p->head_dim,
2059  p->aligned_head_dim,
2061  } else {
2063  p->k,
2064  p->v,
2065  p->attn_out,
2066  p->num_heads,
2067  p->num_kv_heads,
2068  p->tokens,
2069  p->head_dim,
2070  p->aligned_head_dim);
2071  }
 2072 
 2073  ck_attention_project_head_major_q4_k(p->attn_out,
 2074  p->wo,
2075  p->bo,
2076  p->proj_tmp,
2077  p->proj_scratch,
2078  p->tokens,
2079  p->aligned_embed_dim,
2080  p->num_heads,
2081  p->aligned_head_dim);
2082 
2084  p->proj_tmp,
2085  p->residual1,
2086  p->tokens,
2087  p->aligned_embed_dim);
2088 
2090  p->ln2_gamma,
2091  p->ln2_out,
2092  p->ln2_rstd,
2093  p->tokens,
2094  p->embed_dim,
2095  p->aligned_embed_dim,
2096  p->eps);
2097 
2099  p->w1,
2100  p->b1,
2101  p->w2,
2102  p->b2,
2103  p->fc1_out,
2104  p->swiglu_out,
2105  p->mlp_out,
2106  p->tokens,
2107  p->aligned_embed_dim,
2109 
2111  p->mlp_out,
2112  p->output,
2113  p->tokens,
2114  p->aligned_embed_dim);
2115 }
static void ck_attention_project_head_major_q4_k(const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_q4_k_q8_k(const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_attention_project_head_major_q4_k_q8_k(const float *attn_out, const void *wo, const float *bo, float *out, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_q4_k(const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_mlp_swiglu_forward_q4_k_q8_k_prefill(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)

References CKLayerForwardParamsQ4K::aligned_context_window, CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParamsQ4K::attn_out, CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_project_head_major_q4_k(), ck_attention_project_head_major_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_q8k_activations_enabled(), ck_qkv_project_head_major_q4_k(), ck_qkv_project_head_major_q4_k_q8_k(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_scratch, CKLayerForwardParamsQ4K::proj_tmp, CKLayerForwardParamsQ4K::q, QK_K, CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::scores, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::tokens, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wq, and CKLayerForwardParamsQ4K::wv.

◆ ck_layer_forward_rmsnorm_swiglu_quant()

void ck_layer_forward_rmsnorm_swiglu_quant ( const CKLayerForwardParamsQ4K p)

Definition at line 2401 of file ckernel_orchestration.c.

2402 {
2403  if (!p) {
2404  return;
2405  }
2406 
2408  p->ln1_gamma,
2409  p->ln1_out,
2410  p->ln1_rstd,
2411  p->tokens,
2412  p->embed_dim,
2413  p->aligned_embed_dim,
2414  p->eps);
2415 
2417  p->wq, p->bq, p->wq_dtype,
2418  p->wk, p->bk, p->wk_dtype,
2419  p->wv, p->bv, p->wv_dtype,
2420  p->q, p->k, p->v,
2421  p->tokens,
2422  p->tokens,
2423  p->aligned_embed_dim,
2424  p->num_heads,
2425  p->num_kv_heads,
2426  p->aligned_head_dim);
2427 
2428  if (p->rope_cos && p->rope_sin) {
2429  rope_forward_qk(p->q,
2430  p->k,
2431  p->rope_cos,
2432  p->rope_sin,
2433  p->num_heads,
2434  p->num_kv_heads,
2435  p->tokens,
2436  p->head_dim,
2437  p->aligned_head_dim,
2438  p->rope_pos_offset);
2439  }
2440 
2441  if (p->scores) {
2443  p->k,
2444  p->v,
2445  p->scores,
2446  p->attn_out,
2447  p->num_heads,
2448  p->num_kv_heads,
2449  p->tokens,
2450  p->head_dim,
2451  p->aligned_head_dim,
2453  } else {
2455  p->k,
2456  p->v,
2457  p->attn_out,
2458  p->num_heads,
2459  p->num_kv_heads,
2460  p->tokens,
2461  p->head_dim,
2462  p->aligned_head_dim);
2463  }
2464 
2466  p->wo,
2467  p->bo,
2468  p->proj_tmp,
2469  p->proj_scratch,
2470  p->tokens,
2471  p->aligned_embed_dim,
2472  p->num_heads,
2473  p->aligned_head_dim,
2474  p->wo_dtype);
2475 
2477  p->proj_tmp,
2478  p->residual1,
2479  p->tokens,
2480  p->aligned_embed_dim);
2481 
2483  p->ln2_gamma,
2484  p->ln2_out,
2485  p->ln2_rstd,
2486  p->tokens,
2487  p->embed_dim,
2488  p->aligned_embed_dim,
2489  p->eps);
2490 
2492  p->w1,
2493  p->b1,
2494  p->w1_dtype,
2495  p->w2,
2496  p->b2,
2497  p->w2_dtype,
2498  p->fc1_out,
2499  p->swiglu_out,
2500  p->mlp_out,
2501  p->tokens,
2502  p->aligned_embed_dim,
2504 
2506  p->mlp_out,
2507  p->output,
2508  p->tokens,
2509  p->aligned_embed_dim);
2510 }
static void ck_qkv_project_head_major_quant(const float *input, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_attention_project_head_major_quant(const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, CKDataType wo_dtype)

References CKLayerForwardParamsQ4K::aligned_context_window, CKLayerForwardParamsQ4K::aligned_embed_dim, CKLayerForwardParamsQ4K::aligned_head_dim, CKLayerForwardParamsQ4K::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParamsQ4K::attn_out, CKLayerForwardParamsQ4K::b1, CKLayerForwardParamsQ4K::b2, CKLayerForwardParamsQ4K::bk, CKLayerForwardParamsQ4K::bo, CKLayerForwardParamsQ4K::bq, CKLayerForwardParamsQ4K::bv, ck_attention_project_head_major_quant(), ck_mlp_swiglu_forward_quant(), ck_qkv_project_head_major_quant(), ck_residual_add_token_major(), CKLayerForwardParamsQ4K::embed_dim, CKLayerForwardParamsQ4K::eps, CKLayerForwardParamsQ4K::fc1_out, CKLayerForwardParamsQ4K::head_dim, CKLayerForwardParamsQ4K::input, CKLayerForwardParamsQ4K::k, CKLayerForwardParamsQ4K::ln1_gamma, CKLayerForwardParamsQ4K::ln1_out, CKLayerForwardParamsQ4K::ln1_rstd, CKLayerForwardParamsQ4K::ln2_gamma, CKLayerForwardParamsQ4K::ln2_out, CKLayerForwardParamsQ4K::ln2_rstd, CKLayerForwardParamsQ4K::mlp_out, CKLayerForwardParamsQ4K::num_heads, CKLayerForwardParamsQ4K::num_kv_heads, CKLayerForwardParamsQ4K::output, CKLayerForwardParamsQ4K::proj_scratch, CKLayerForwardParamsQ4K::proj_tmp, CKLayerForwardParamsQ4K::q, CKLayerForwardParamsQ4K::residual1, rmsnorm_forward(), CKLayerForwardParamsQ4K::rope_cos, rope_forward_qk(), CKLayerForwardParamsQ4K::rope_pos_offset, CKLayerForwardParamsQ4K::rope_sin, CKLayerForwardParamsQ4K::scores, CKLayerForwardParamsQ4K::swiglu_out, CKLayerForwardParamsQ4K::tokens, CKLayerForwardParamsQ4K::v, CKLayerForwardParamsQ4K::w1, CKLayerForwardParamsQ4K::w1_dtype, CKLayerForwardParamsQ4K::w2, CKLayerForwardParamsQ4K::w2_dtype, CKLayerForwardParamsQ4K::wk, CKLayerForwardParamsQ4K::wk_dtype, CKLayerForwardParamsQ4K::wo, CKLayerForwardParamsQ4K::wo_dtype, CKLayerForwardParamsQ4K::wq, CKLayerForwardParamsQ4K::wq_dtype, CKLayerForwardParamsQ4K::wv, and CKLayerForwardParamsQ4K::wv_dtype.

◆ ck_layer_forward_rmsnorm_swiglu_ref()

void ck_layer_forward_rmsnorm_swiglu_ref ( const CKLayerForwardParams p)

Definition at line 1104 of file ckernel_orchestration.c.

1105 {
1106  if (!p) {
1107  return;
1108  }
1109 
1111  p->ln1_gamma,
1112  p->ln1_out,
1113  p->ln1_rstd,
1114  p->tokens,
1115  p->embed_dim,
1116  p->aligned_embed_dim,
1117  p->eps);
1118 
1120  p->wq, p->bq,
1121  p->wk, p->bk,
1122  p->wv, p->bv,
1123  p->q, p->k, p->v,
1124  p->tokens,
1125  p->tokens,
1126  p->aligned_embed_dim,
1127  p->num_heads,
1128  p->num_kv_heads,
1129  p->aligned_head_dim);
1130 
1131  if (p->rope_cos && p->rope_sin) {
1132  rope_forward_qk(p->q,
1133  p->k,
1134  p->rope_cos,
1135  p->rope_sin,
1136  p->num_heads,
1137  p->num_kv_heads,
1138  p->tokens,
1139  p->head_dim,
1140  p->aligned_head_dim,
1141  p->rope_pos_offset);
1142  }
1143 
1144  if (p->scores) {
1146  p->k,
1147  p->v,
1148  p->scores,
1149  p->attn_out,
1150  p->num_heads,
1151  p->num_kv_heads,
1152  p->tokens,
1153  p->head_dim,
1154  p->aligned_head_dim,
1156  } else {
1158  p->k,
1159  p->v,
1160  p->attn_out,
1161  p->num_heads,
1162  p->num_kv_heads,
1163  p->tokens,
1164  p->head_dim,
1165  p->aligned_head_dim);
1166  }
1167 
1169  p->wo,
1170  p->bo,
1171  p->proj_tmp,
1172  p->proj_scratch,
1173  p->tokens,
1174  p->aligned_embed_dim,
1175  p->num_heads,
1176  p->aligned_head_dim);
1177 
1179  p->proj_tmp,
1180  p->residual1,
1181  p->tokens,
1182  p->aligned_embed_dim);
1183 
1185  p->ln2_gamma,
1186  p->ln2_out,
1187  p->ln2_rstd,
1188  p->tokens,
1189  p->embed_dim,
1190  p->aligned_embed_dim,
1191  p->eps);
1192 
1194  p->w1,
1195  p->b1,
1196  p->w2,
1197  p->b2,
1198  p->fc1_out,
1199  p->swiglu_out,
1200  p->mlp_out,
1201  p->tokens,
1202  p->aligned_embed_dim,
1204 
1206  p->mlp_out,
1207  p->output,
1208  p->tokens,
1209  p->aligned_embed_dim);
1210 }
static void ck_attention_project_head_major_ref(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_ref(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_mlp_swiglu_forward_ref(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)

References CKLayerForwardParams::aligned_context_window, CKLayerForwardParams::aligned_embed_dim, CKLayerForwardParams::aligned_head_dim, CKLayerForwardParams::aligned_intermediate_dim, attention_forward_causal_head_major_gqa(), attention_forward_causal_head_major_gqa_flash(), CKLayerForwardParams::attn_out, CKLayerForwardParams::b1, CKLayerForwardParams::b2, CKLayerForwardParams::bk, CKLayerForwardParams::bo, CKLayerForwardParams::bq, CKLayerForwardParams::bv, ck_attention_project_head_major_ref(), ck_mlp_swiglu_forward_ref(), ck_qkv_project_head_major_ref(), ck_residual_add_token_major(), CKLayerForwardParams::embed_dim, CKLayerForwardParams::eps, CKLayerForwardParams::fc1_out, CKLayerForwardParams::head_dim, CKLayerForwardParams::input, CKLayerForwardParams::k, CKLayerForwardParams::ln1_gamma, CKLayerForwardParams::ln1_out, CKLayerForwardParams::ln1_rstd, CKLayerForwardParams::ln2_gamma, CKLayerForwardParams::ln2_out, CKLayerForwardParams::ln2_rstd, CKLayerForwardParams::mlp_out, CKLayerForwardParams::num_heads, CKLayerForwardParams::num_kv_heads, CKLayerForwardParams::output, CKLayerForwardParams::proj_scratch, CKLayerForwardParams::proj_tmp, CKLayerForwardParams::q, CKLayerForwardParams::residual1, rmsnorm_forward(), CKLayerForwardParams::rope_cos, rope_forward_qk(), CKLayerForwardParams::rope_pos_offset, CKLayerForwardParams::rope_sin, CKLayerForwardParams::scores, CKLayerForwardParams::swiglu_out, CKLayerForwardParams::tokens, CKLayerForwardParams::v, CKLayerForwardParams::w1, CKLayerForwardParams::w2, CKLayerForwardParams::wk, CKLayerForwardParams::wo, CKLayerForwardParams::wq, and CKLayerForwardParams::wv.

◆ ck_mlp_swiglu_forward()

/**
 * SwiGLU MLP forward pass over a batch of tokens (prefill path).
 *
 * Computes fc1_out = input @ w1 (+ b1) where w1 packs gate and up
 * projections as [2 * aligned_intermediate_dim, aligned_embed_dim],
 * applies the SwiGLU activation into swiglu_out, then projects back down
 * with w2 into output.
 *
 * Fix: added the NULL-pointer guard that every sibling orchestration helper
 * in this file already has (see ck_qkv_project_head_major,
 * ck_mlp_swiglu_forward_fused_token); previously NULL inputs were passed
 * straight into the GEMM kernels.
 *
 * @param input                     [tokens, aligned_embed_dim] activations.
 * @param w1                        Packed gate+up weights; b1 optional bias.
 * @param w2                        Down-projection weights; b2 optional bias.
 * @param fc1_out                   Scratch, [tokens, 2*aligned_intermediate_dim].
 * @param swiglu_out                Scratch, [tokens, aligned_intermediate_dim].
 * @param output                    Result, [tokens, aligned_embed_dim].
 */
void ck_mlp_swiglu_forward(const float *input,
                           const float *w1,
                           const float *b1,
                           const float *w2,
                           const float *b2,
                           float *fc1_out,
                           float *swiglu_out,
                           float *output,
                           int tokens,
                           int aligned_embed_dim,
                           int aligned_intermediate_dim)
{
    if (!input || !w1 || !w2 || !fc1_out || !swiglu_out || !output) {
        return;
    }

    /* Gate and up projections are fused in a single GEMM over 2*Hff columns. */
    int up_dim = 2 * aligned_intermediate_dim;
    gemm_blocked_serial(input, w1, b1, fc1_out,
                        tokens, up_dim, aligned_embed_dim);

    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);

    gemm_blocked_serial(swiglu_out, w2, b2, output,
                        tokens, aligned_embed_dim, aligned_intermediate_dim);
}
void swiglu_forward(const float *input, float *output, int tokens, int dim)

References gemm_blocked_serial(), and swiglu_forward().

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_mlp_swiglu_forward_fully_fused_token()

/**
 * Single-token SwiGLU MLP using the fully fused decode kernel.
 *
 * Splits the packed first-layer weights/bias into their gate and up halves
 * (w1 layout: [2 * aligned_intermediate_dim, aligned_embed_dim], gate half
 * first) and hands everything to fused_mlp_swiglu_decode_v2, which keeps the
 * intermediate SwiGLU activations out of DRAM.
 *
 * @param input_row   One token's activations, [aligned_embed_dim].
 * @param w1, b1      Packed gate+up weights and optional packed bias.
 * @param w2, b2      Down-projection weights and optional bias.
 * @param output_row  Result row, [aligned_embed_dim].
 */
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row,
                                             const float *w1,
                                             const float *b1,
                                             const float *w2,
                                             const float *b2,
                                             float *output_row,
                                             int aligned_embed_dim,
                                             int aligned_intermediate_dim)
{
    if (input_row == NULL || w1 == NULL || w2 == NULL || output_row == NULL) {
        return;
    }

    /* Gate weights occupy the first half of w1, up weights the second. */
    const size_t half = (size_t)aligned_intermediate_dim * (size_t)aligned_embed_dim;
    const float *gate_w = w1;
    const float *up_w = w1 + half;

    /* The packed bias (when present) follows the same gate-then-up layout. */
    const float *gate_b = b1;
    const float *up_b = (b1 != NULL) ? b1 + aligned_intermediate_dim : NULL;

    /* Fully fused kernel: gate GEMM, up GEMM, SwiGLU, and down projection
       in one pass, using the aligned (padded) dimensions of the weights. */
    fused_mlp_swiglu_decode_v2(input_row,
                               gate_w,
                               up_w,
                               w2,
                               gate_b,
                               up_b,
                               b2,
                               output_row,
                               aligned_embed_dim,
                               aligned_intermediate_dim);
}
void fused_mlp_swiglu_decode_v2(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)

References fused_mlp_swiglu_decode_v2().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_mlp_swiglu_forward_fused_token()

/**
 * Single-token SwiGLU MLP with a fused gate+up first stage.
 *
 * Stage 1 runs the gate and up projections plus the SwiGLU activation in one
 * fused kernel into swiglu_row; stage 2 is a plain down-projection GEMM into
 * output_row. w1 packs gate then up halves: [2 * Hff, D].
 */
void ck_mlp_swiglu_forward_fused_token(const float *input_row,
                                       const float *w1,
                                       const float *b1,
                                       const float *w2,
                                       const float *b2,
                                       float *swiglu_row,
                                       float *output_row,
                                       int aligned_embed_dim,
                                       int aligned_intermediate_dim)
{
    if (input_row == NULL || w1 == NULL || w2 == NULL ||
        swiglu_row == NULL || output_row == NULL) {
        return;
    }

    const size_t gate_span =
        (size_t)aligned_intermediate_dim * (size_t)aligned_embed_dim;
    const float *up_weights = w1 + gate_span;
    const float *up_bias =
        (b1 != NULL) ? b1 + aligned_intermediate_dim : NULL;

    /* Fused gate/up GEMM + SwiGLU for a single row (M = 1). */
    gemm_swiglu_fused(input_row,
                      w1,          /* gate weights: first half of w1 */
                      up_weights,
                      b1,          /* gate bias (may be NULL) */
                      up_bias,
                      swiglu_row,
                      /*M=*/1,
                      /*N=*/aligned_intermediate_dim,
                      /*K=*/aligned_embed_dim);

    /* Down projection back to the embedding dimension. */
    gemm_blocked_serial(swiglu_row, w2, b2, output_row,
                        /*M=*/1,
                        /*N=*/aligned_embed_dim,
                        /*K=*/aligned_intermediate_dim);
}
void gemm_swiglu_fused(const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)

References gemm_blocked_serial(), and gemm_swiglu_fused().

◆ ck_qkv_project_head_major()

/**
 * Head-major QKV projection over a batch of tokens.
 *
 * Runs one GEMM per query head and one K GEMM + one V GEMM per KV head.
 * Weights are packed head-major with stride aligned_head_dim *
 * aligned_embed_dim; Q output is laid out per head with stride
 * tokens * aligned_head_dim, while K/V use kv_stride_tokens so they can be
 * written directly into a larger KV cache.
 *
 * No-op when any required pointer is NULL or when kv_stride_tokens cannot
 * hold `tokens` rows. Bias pointers are optional (NULL means no bias).
 */
void ck_qkv_project_head_major(const float *input,
                               const float *wq,
                               const float *bq,
                               const float *wk,
                               const float *bk,
                               const float *wv,
                               const float *bv,
                               float *q,
                               float *k,
                               float *v,
                               int tokens,
                               int kv_stride_tokens,
                               int aligned_embed_dim,
                               int num_heads,
                               int num_kv_heads,
                               int aligned_head_dim)
{
    if (input == NULL || wq == NULL || wk == NULL || wv == NULL ||
        q == NULL || k == NULL || v == NULL) {
        return;
    }
    if (kv_stride_tokens < tokens) {
        /* The KV layout must have room for at least `tokens` rows per head. */
        return;
    }

    const size_t w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t kv_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;

    /* Query heads. */
    for (int head = 0; head < num_heads; ++head) {
        const size_t bias_off = (size_t)head * (size_t)aligned_head_dim;
        gemm_blocked_serial(input,
                            wq + (size_t)head * w_stride,
                            (bq != NULL) ? bq + bias_off : NULL,
                            q + (size_t)head * q_stride,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }

    /* Key and value heads (may be fewer than query heads). */
    for (int head = 0; head < num_kv_heads; ++head) {
        const size_t w_off = (size_t)head * w_stride;
        const size_t bias_off = (size_t)head * (size_t)aligned_head_dim;
        const size_t out_off = (size_t)head * kv_stride;

        gemm_blocked_serial(input,
                            wk + w_off,
                            (bk != NULL) ? bk + bias_off : NULL,
                            k + out_off,
                            tokens, aligned_head_dim, aligned_embed_dim);
        gemm_blocked_serial(input,
                            wv + w_off,
                            (bv != NULL) ? bv + bias_off : NULL,
                            v + out_off,
                            tokens, aligned_head_dim, aligned_embed_dim);
    }
}

References gemm_blocked_serial().

Referenced by ck_layer_forward_rmsnorm_swiglu().

◆ ck_qkv_project_head_major_backward()

/**
 * Backward pass of the head-major QKV projection.
 *
 * For each head, fc2_backward_kernel produces the per-head weight and bias
 * gradients plus the input gradient for that head in `scratch`; the input
 * gradients are accumulated across all heads into d_input (zeroed first).
 * Q heads are processed first, then K and V for each KV head.
 *
 * The forward bias values (bq/bk/bv) are accepted for signature symmetry but
 * are not needed to compute gradients. No-op when any required pointer is
 * NULL.
 */
void ck_qkv_project_head_major_backward(const float *d_q,
                                        const float *d_k,
                                        const float *d_v,
                                        const float *input,
                                        const float *wq,
                                        const float *bq,
                                        const float *wk,
                                        const float *bk,
                                        const float *wv,
                                        const float *bv,
                                        float *d_input,
                                        float *d_wq,
                                        float *d_bq,
                                        float *d_wk,
                                        float *d_bk,
                                        float *d_wv,
                                        float *d_bv,
                                        float *scratch,
                                        int tokens,
                                        int aligned_embed_dim,
                                        int num_heads,
                                        int num_kv_heads,
                                        int aligned_head_dim,
                                        int num_threads)
{
    if (d_q == NULL || d_k == NULL || d_v == NULL || input == NULL ||
        wq == NULL || wk == NULL || wv == NULL ||
        d_input == NULL || d_wq == NULL || d_bq == NULL ||
        d_wk == NULL || d_bk == NULL || d_wv == NULL || d_bv == NULL ||
        scratch == NULL) {
        return;
    }

    /* d_input accumulates contributions from every head; start from zero. */
    const size_t in_total = (size_t)tokens * (size_t)aligned_embed_dim;
    for (size_t i = 0; i < in_total; ++i) {
        d_input[i] = 0.0f;
    }

    const size_t w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t out_stride = (size_t)tokens * (size_t)aligned_head_dim;

    /* Query heads. */
    for (int head = 0; head < num_heads; ++head) {
        const size_t w_off = (size_t)head * w_stride;
        const size_t b_off = (size_t)head * (size_t)aligned_head_dim;

        fc2_backward_kernel(d_q + (size_t)head * out_stride,
                            input,
                            wq + w_off,
                            scratch,
                            d_wq + w_off,
                            d_bq + b_off,
                            tokens,
                            aligned_embed_dim,
                            aligned_head_dim,
                            num_threads);
        ck_add_inplace(d_input, scratch, tokens, aligned_embed_dim);
    }

    /* Key and value heads. */
    for (int head = 0; head < num_kv_heads; ++head) {
        const size_t w_off = (size_t)head * w_stride;
        const size_t b_off = (size_t)head * (size_t)aligned_head_dim;
        const size_t o_off = (size_t)head * out_stride;

        fc2_backward_kernel(d_k + o_off,
                            input,
                            wk + w_off,
                            scratch,
                            d_wk + w_off,
                            d_bk + b_off,
                            tokens,
                            aligned_embed_dim,
                            aligned_head_dim,
                            num_threads);
        ck_add_inplace(d_input, scratch, tokens, aligned_embed_dim);

        fc2_backward_kernel(d_v + o_off,
                            input,
                            wv + w_off,
                            scratch,
                            d_wv + w_off,
                            d_bv + b_off,
                            tokens,
                            aligned_embed_dim,
                            aligned_head_dim,
                            num_threads);
        ck_add_inplace(d_input, scratch, tokens, aligned_embed_dim);
    }
}

References ck_add_inplace(), and fc2_backward_kernel().

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ ck_qkv_project_head_major_token()

/**
 * Single-token (decode-step) head-major QKV projection.
 *
 * All query heads are computed with one GEMM over the packed wq; K and V are
 * projected per KV head (optionally in parallel via OpenMP) into contiguous
 * per-head slots of k_token / v_token. Bias pointers are optional.
 */
void ck_qkv_project_head_major_token(const float *input_row,
                                     const float *wq,
                                     const float *bq,
                                     const float *wk,
                                     const float *bk,
                                     const float *wv,
                                     const float *bv,
                                     float *q_token,
                                     float *k_token,
                                     float *v_token,
                                     int aligned_embed_dim,
                                     int num_heads,
                                     int num_kv_heads,
                                     int aligned_head_dim)
{
    if (input_row == NULL || wq == NULL || wk == NULL || wv == NULL ||
        q_token == NULL || k_token == NULL || v_token == NULL) {
        return;
    }

    /* All query heads at once: wq is packed head-major, so a single GEMM of
       width num_heads * aligned_head_dim covers them. */
    gemm_blocked_serial(input_row, wq, bq, q_token,
                        /*tokens=*/1,
                        num_heads * aligned_head_dim,
                        aligned_embed_dim);

    const size_t w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
#pragma omp parallel for schedule(static) if(num_kv_heads > 1)
    for (int head = 0; head < num_kv_heads; ++head) {
        const size_t off = (size_t)head * (size_t)aligned_head_dim;

        gemm_blocked_serial(input_row,
                            wk + (size_t)head * w_stride,
                            (bk != NULL) ? bk + off : NULL,
                            k_token + off,
                            /*tokens=*/1, aligned_head_dim, aligned_embed_dim);
        gemm_blocked_serial(input_row,
                            wv + (size_t)head * w_stride,
                            (bv != NULL) ? bv + off : NULL,
                            v_token + off,
                            /*tokens=*/1, aligned_head_dim, aligned_embed_dim);
    }
}

References gemm_blocked_serial().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), and ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl().

◆ ck_residual_add_backward()

/**
 * Backward pass of a residual addition: since out = a + b, the upstream
 * gradient flows unchanged to both operands, so d_a and d_b each receive a
 * copy of d_out. No-op when any pointer is NULL.
 *
 * @param d_out              Upstream gradient, [tokens, aligned_embed_dim].
 * @param d_a, d_b           Output gradients, same shape as d_out.
 */
void ck_residual_add_backward(const float *d_out,
                              float *d_a,
                              float *d_b,
                              int tokens,
                              int aligned_embed_dim)
{
    if (d_out == NULL || d_a == NULL || d_b == NULL) {
        return;
    }
    const size_t count = (size_t)tokens * (size_t)aligned_embed_dim;
    for (size_t idx = 0; idx < count; ++idx) {
        const float g = d_out[idx];
        d_a[idx] = g;
        d_b[idx] = g;
    }
}

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ ck_residual_add_token_major()

void ck_residual_add_token_major ( const float *  a,
const float *  b,
float *  out,
int  tokens,
int  aligned_embed_dim 
)