← Back to C-Kernel-Engine Docs | Doxygen Source Documentation
ckernel_kernel_specs.c
Go to the documentation of this file.
1 #include "ckernel_kernel_specs.h"
2 
4  {"token_emb", CK_SCOPE_GLOBAL, CK_ROLE_WEIGHT, { { CK_DIM_VOCAB, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
5  {"pos_emb", CK_SCOPE_GLOBAL, CK_ROLE_WEIGHT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
6  {"embedded_input", CK_SCOPE_GLOBAL, CK_ROLE_ACTIVATION, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
7  {"rope_cos_cache", CK_SCOPE_GLOBAL, CK_ROLE_ACTIVATION, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_HEAD_DIM, 1, 2 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, "rope_theta", CK_DT_FP32},
8  {"rope_sin_cache", CK_SCOPE_GLOBAL, CK_ROLE_ACTIVATION, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_HEAD_DIM, 1, 2 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, "rope_theta", CK_DT_FP32},
9  {"final_ln_weight", CK_SCOPE_GLOBAL, CK_ROLE_WEIGHT, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
10  {"final_ln_bias", CK_SCOPE_GLOBAL, CK_ROLE_WEIGHT, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
11  {"final_ln_mean", CK_SCOPE_GLOBAL, CK_ROLE_ACTIVATION, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
12  {"final_ln_rstd", CK_SCOPE_GLOBAL, CK_ROLE_ACTIVATION, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
13  {"final_output", CK_SCOPE_GLOBAL, CK_ROLE_ACTIVATION, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
14  {"lm_head_weight", CK_SCOPE_GLOBAL, CK_ROLE_WEIGHT, { { CK_DIM_VOCAB, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, "token_emb", NULL, CK_DT_FP32},
15  {"logits", CK_SCOPE_GLOBAL, CK_ROLE_ACTIVATION, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_VOCAB, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
16  {"d_token_emb", CK_SCOPE_GLOBAL, CK_ROLE_GRAD, { { CK_DIM_VOCAB, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
17  {"d_pos_emb", CK_SCOPE_GLOBAL, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
18  {"d_final_output", CK_SCOPE_GLOBAL, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
19  {"d_final_input", CK_SCOPE_GLOBAL, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
20  {"d_final_ln_weight", CK_SCOPE_GLOBAL, CK_ROLE_GRAD, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
21  {"d_logits", CK_SCOPE_GLOBAL, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_VOCAB, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
22  {"input", CK_SCOPE_LAYER, CK_ROLE_INPUT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
23  {"ln1_gamma", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
24  {"ln1_out", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
25  {"ln1_rstd", CK_SCOPE_LAYER, CK_ROLE_ACTIVATION, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 1, NULL, NULL, CK_DT_FP32},
26  {"wq", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
27  {"bq", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
28  {"wk", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
29  {"bk", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
30  {"wv", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
31  {"bv", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
32  {"q", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
33  {"k", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
34  {"v", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
35  {"scores", CK_SCOPE_LAYER, CK_ROLE_ACTIVATION, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_ALIGNED_CTX, 1, 1 }, { CK_DIM_ALIGNED_CTX, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, "training_enabled", CK_DT_FP32},
36  {"attn_out", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
37  {"wo", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
38  {"bo", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
39  {"proj_tmp", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
40  {"proj_scratch", CK_SCOPE_LAYER, CK_ROLE_SCRATCH, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
41  {"residual1", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
42  {"ln2_gamma", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
43  {"ln2_out", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
44  {"ln2_rstd", CK_SCOPE_LAYER, CK_ROLE_ACTIVATION, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 1, NULL, NULL, CK_DT_FP32},
45  {"w1", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_ALIGNED_INTERMEDIATE, 2, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
46  {"b1", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_ALIGNED_INTERMEDIATE, 2, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
47  {"fc1_out", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_INTERMEDIATE, 2, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
48  {"swiglu_out", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_INTERMEDIATE, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
49  {"w2", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_ALIGNED_INTERMEDIATE, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
50  {"b2", CK_SCOPE_LAYER, CK_ROLE_WEIGHT, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
51  {"mlp_out", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
52  {"output", CK_SCOPE_LAYER, CK_ROLE_OUTPUT, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
53  {"d_output", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
54  {"d_residual1", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
55  {"d_mlp_out", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
56  {"d_swiglu_out", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_INTERMEDIATE, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
57  {"d_w2", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_ALIGNED_INTERMEDIATE, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
58  {"d_b2", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
59  {"d_fc1_out", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_INTERMEDIATE, 2, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
60  {"d_ln2_out", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
61  {"d_w1", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_ALIGNED_INTERMEDIATE, 2, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
62  {"d_b1", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_ALIGNED_INTERMEDIATE, 2, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
63  {"d_ln2_gamma", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
64  {"d_input", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
65  {"d_proj_tmp", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
66  {"d_attn_out", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
67  {"d_wo", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
68  {"d_bo", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
69  {"d_q", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
70  {"d_k", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
71  {"d_v", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
72  {"d_scores", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_ALIGNED_CTX, 1, 1 }, { CK_DIM_ALIGNED_CTX, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
73  {"d_ln1_out", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_TOKENS, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
74  {"d_wq", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
75  {"d_bq", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
76  {"d_wk", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
77  {"d_bk", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
78  {"d_wv", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
79  {"d_bv", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_NUM_KV_HEADS, 1, 1 }, { CK_DIM_ALIGNED_HEAD, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
80  {"d_ln1_gamma", CK_SCOPE_LAYER, CK_ROLE_GRAD, { { CK_DIM_ALIGNED_EMBED, 1, 1 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 }, { CK_DIM_END, 0, 0 } }, 0, NULL, NULL, CK_DT_FP32},
81 };
82 
84 
86  {"attention", { "attention_forward_causal_head_major_gqa", "attention_forward_causal_head_major_gqa_bf16", NULL, NULL, NULL }, { "attention_backward_causal_head_major_gqa", "attention_backward_causal_head_major_gqa_bf16", NULL, NULL, NULL }, CK_DT_MASK(CK_DT_FP32) | CK_DT_MASK(CK_DT_BF16), CK_DT_FP32, { "src/kernels/attention_kernels.c", "src/kernels/softmax_kernels.c", NULL, NULL, NULL, NULL, NULL, NULL }},
87  {"attn_proj", { "ck_attention_project_head_major", NULL, NULL, NULL, NULL }, { "ck_attention_project_head_major_backward", NULL, NULL, NULL, NULL }, CK_DT_MASK(CK_DT_FP32), CK_DT_FP32, { "src/ckernel_orchestration.c", "src/kernels/gemm_kernels.c", "src/kernels/mlp_kernels.c", "src/kernels/gelu_kernels.c", NULL, NULL, NULL, NULL }},
88  {"mlp_down", { "gemm_blocked_serial", NULL, NULL, NULL, NULL }, { "fc2_backward_kernel", NULL, NULL, NULL, NULL }, CK_DT_MASK(CK_DT_FP32), CK_DT_FP32, { "src/kernels/gemm_kernels.c", "src/kernels/mlp_kernels.c", "src/kernels/gelu_kernels.c", NULL, NULL, NULL, NULL, NULL }},
89  {"mlp_up", { "gemm_blocked_serial", NULL, NULL, NULL, NULL }, { "fc1_backward_kernel", NULL, NULL, NULL, NULL }, CK_DT_MASK(CK_DT_FP32), CK_DT_FP32, { "src/kernels/gemm_kernels.c", "src/kernels/mlp_kernels.c", "src/kernels/gelu_kernels.c", NULL, NULL, NULL, NULL, NULL }},
90  {"qkv_project", { "ck_qkv_project_head_major", NULL, NULL, NULL, NULL }, { "ck_qkv_project_head_major_backward", NULL, NULL, NULL, NULL }, CK_DT_MASK(CK_DT_FP32), CK_DT_FP32, { "src/ckernel_orchestration.c", "src/kernels/gemm_kernels.c", "src/kernels/mlp_kernels.c", "src/kernels/gelu_kernels.c", NULL, NULL, NULL, NULL }},
91  {"residual_add", { "ck_residual_add_token_major", NULL, NULL, NULL, NULL }, { "ck_residual_add_backward", NULL, NULL, NULL, NULL }, CK_DT_MASK(CK_DT_FP32), CK_DT_FP32, { "src/ckernel_orchestration.c", NULL, NULL, NULL, NULL, NULL, NULL, NULL }},
92  {"rmsnorm", { "rmsnorm_forward", "rmsnorm_forward_bf16", NULL, "rmsnorm_forward_int8", "rmsnorm_forward_int4" }, { "rmsnorm_backward", "rmsnorm_backward_bf16", NULL, "rmsnorm_backward_int8", "rmsnorm_backward_int4" }, CK_DT_MASK(CK_DT_FP32) | CK_DT_MASK(CK_DT_BF16) | CK_DT_MASK(CK_DT_INT8) | CK_DT_MASK(CK_DT_INT4), CK_DT_FP32, { "src/kernels/rmsnorm_kernels.c", "src/kernels/rmsnorm_kernels_bf16.c", "src/kernels/rmsnorm_kernels_int8.c", "src/kernels/rmsnorm_kernels_int4.c", NULL, NULL, NULL, NULL }},
93  {"rope", { "rope_forward_qk", "rope_forward_qk_bf16", NULL, NULL, NULL }, { "rope_backward_qk", "rope_backward_qk_bf16", NULL, NULL, NULL }, CK_DT_MASK(CK_DT_FP32) | CK_DT_MASK(CK_DT_BF16), CK_DT_FP32, { "src/kernels/rope_kernels.c", "src/kernels/rope_kernels_bf16.c", NULL, NULL, NULL, NULL, NULL, NULL }},
94  {"swiglu", { "swiglu_forward", "swiglu_forward_bf16", NULL, NULL, NULL }, { "swiglu_backward", "swiglu_backward_bf16", NULL, NULL, NULL }, CK_DT_MASK(CK_DT_FP32) | CK_DT_MASK(CK_DT_BF16), CK_DT_FP32, { "src/kernels/swiglu_kernels.c", "src/kernels/swiglu_kernels_bf16.c", "src/kernels/sigmoid_kernels.c", NULL, NULL, NULL, NULL, NULL }},
95 };
96 
/* Number of entries in ck_kernel_specs: total array size in bytes divided by the size of one element. */
97 const size_t ck_kernel_spec_count = sizeof(ck_kernel_specs) / sizeof(ck_kernel_specs[0]);
98 
100  {"rmsnorm", NULL},
101  {"qkv_project", NULL},
102  {"rope", "rope_theta>0"},
103  {"attention", NULL},
104  {"attn_proj", NULL},
105  {"residual_add", NULL},
106  {"rmsnorm", NULL},
107  {"mlp_up", NULL},
108  {"swiglu", NULL},
109  {"mlp_down", NULL},
110  {"residual_add", NULL},
111 };
112 
114 
116  {"residual_add", NULL},
117  {"mlp_down", NULL},
118  {"swiglu", NULL},
119  {"mlp_up", NULL},
120  {"rmsnorm", NULL},
121  {"residual_add", NULL},
122  {"attn_proj", NULL},
123  {"attention", NULL},
124  {"rope", "rope_theta>0"},
125  {"qkv_project", NULL},
126  {"rmsnorm", NULL},
127 };
128 
130 
132  {"input", "input"},
133  {"gamma", "ln1_gamma"},
134  {"out", "ln1_out"},
135  {"rstd", "ln1_rstd"},
136 };
137 
139  {"input", "ln1_out"},
140  {"wq", "wq"},
141  {"bq", "bq"},
142  {"wk", "wk"},
143  {"bk", "bk"},
144  {"wv", "wv"},
145  {"bv", "bv"},
146  {"q", "q"},
147  {"k", "k"},
148  {"v", "v"},
149 };
150 
152  {"q", "q"},
153  {"k", "k"},
154  {"cos_cache", "rope_cos_cache"},
155  {"sin_cache", "rope_sin_cache"},
156 };
157 
159  {"q", "q"},
160  {"k", "k"},
161  {"v", "v"},
162  {"scores", "scores"},
163  {"attn_out", "attn_out"},
164 };
165 
167  {"attn_out", "attn_out"},
168  {"wo", "wo"},
169  {"bo", "bo"},
170  {"proj_tmp", "proj_tmp"},
171  {"proj_scratch", "proj_scratch"},
172 };
173 
175  {"a", "input"},
176  {"b", "proj_tmp"},
177  {"out", "residual1"},
178 };
179 
181  {"input", "residual1"},
182  {"gamma", "ln2_gamma"},
183  {"out", "ln2_out"},
184  {"rstd", "ln2_rstd"},
185 };
186 
188  {"input", "ln2_out"},
189  {"w1", "w1"},
190  {"b1", "b1"},
191  {"fc1_out", "fc1_out"},
192 };
193 
195  {"fc1_out", "fc1_out"},
196  {"swiglu_out", "swiglu_out"},
197 };
198 
200  {"swiglu_out", "swiglu_out"},
201  {"w2", "w2"},
202  {"b2", "b2"},
203  {"mlp_out", "mlp_out"},
204 };
205 
207  {"a", "residual1"},
208  {"b", "mlp_out"},
209  {"out", "output"},
210 };
211 
213  {"d_out", "d_output"},
214  {"d_a", "d_residual1"},
215  {"d_b", "d_mlp_out"},
216 };
217 
219  {"d_out", "d_mlp_out"},
220  {"swiglu_out", "swiglu_out"},
221  {"w2", "w2"},
222  {"d_input", "d_swiglu_out"},
223  {"d_w2", "d_w2"},
224  {"d_b2", "d_b2"},
225 };
226 
228  {"fc1_out", "fc1_out"},
229  {"d_out", "d_swiglu_out"},
230  {"d_input", "d_fc1_out"},
231 };
232 
234  {"d_out", "d_fc1_out"},
235  {"input", "ln2_out"},
236  {"w1", "w1"},
237  {"d_input", "d_ln2_out"},
238  {"d_w1", "d_w1"},
239  {"d_b1", "d_b1"},
240 };
241 
243  {"d_out", "d_ln2_out"},
244  {"input", "residual1"},
245  {"gamma", "ln2_gamma"},
246  {"rstd", "ln2_rstd"},
247  {"d_input", "d_residual1"},
248  {"d_gamma", "d_ln2_gamma"},
249 };
250 
252  {"d_out", "d_residual1"},
253  {"d_a", "d_input"},
254  {"d_b", "d_proj_tmp"},
255 };
256 
258  {"d_out", "d_proj_tmp"},
259  {"attn_out", "attn_out"},
260  {"wo", "wo"},
261  {"d_attn_out", "d_attn_out"},
262  {"d_wo", "d_wo"},
263  {"d_bo", "d_bo"},
264 };
265 
267  {"d_out", "d_attn_out"},
268  {"q", "q"},
269  {"k", "k"},
270  {"v", "v"},
271  {"scores", "scores"},
272  {"d_q", "d_q"},
273  {"d_k", "d_k"},
274  {"d_v", "d_v"},
275  {"d_scores", "d_scores"},
276 };
277 
279  {"d_q_out", "d_q"},
280  {"d_k_out", "d_k"},
281  {"d_q", "d_q"},
282  {"d_k", "d_k"},
283  {"cos_cache", "rope_cos_cache"},
284  {"sin_cache", "rope_sin_cache"},
285 };
286 
288  {"d_q", "d_q"},
289  {"d_k", "d_k"},
290  {"d_v", "d_v"},
291  {"input", "ln1_out"},
292  {"wq", "wq"},
293  {"wk", "wk"},
294  {"wv", "wv"},
295  {"d_input", "d_ln1_out"},
296  {"d_wq", "d_wq"},
297  {"d_bq", "d_bq"},
298  {"d_wk", "d_wk"},
299  {"d_bk", "d_bk"},
300  {"d_wv", "d_wv"},
301  {"d_bv", "d_bv"},
302 };
303 
305  {"d_out", "d_ln1_out"},
306  {"input", "input"},
307  {"gamma", "ln1_gamma"},
308  {"rstd", "ln1_rstd"},
309  {"d_input", "d_input"},
310  {"d_gamma", "d_ln1_gamma"},
311 };
312 
314  {"rmsnorm", NULL, ck_decoder_forward_bindings_0, 4},
315  {"qkv_project", NULL, ck_decoder_forward_bindings_1, 10},
316  {"rope", "rope_theta>0", ck_decoder_forward_bindings_2, 4},
317  {"attention", NULL, ck_decoder_forward_bindings_3, 5},
318  {"attn_proj", NULL, ck_decoder_forward_bindings_4, 5},
319  {"residual_add", NULL, ck_decoder_forward_bindings_5, 3},
320  {"rmsnorm", NULL, ck_decoder_forward_bindings_6, 4},
321  {"mlp_up", NULL, ck_decoder_forward_bindings_7, 4},
322  {"swiglu", NULL, ck_decoder_forward_bindings_8, 2},
323  {"mlp_down", NULL, ck_decoder_forward_bindings_9, 4},
324  {"residual_add", NULL, ck_decoder_forward_bindings_10, 3},
325 };
326 
328 
330  {"residual_add", NULL, ck_decoder_backward_bindings_0, 3},
331  {"mlp_down", NULL, ck_decoder_backward_bindings_1, 6},
332  {"swiglu", NULL, ck_decoder_backward_bindings_2, 3},
333  {"mlp_up", NULL, ck_decoder_backward_bindings_3, 6},
334  {"rmsnorm", NULL, ck_decoder_backward_bindings_4, 6},
335  {"residual_add", NULL, ck_decoder_backward_bindings_5, 3},
336  {"attn_proj", NULL, ck_decoder_backward_bindings_6, 6},
337  {"attention", NULL, ck_decoder_backward_bindings_7, 9},
338  {"rope", "rope_theta>0", ck_decoder_backward_bindings_8, 6},
339  {"qkv_project", NULL, ck_decoder_backward_bindings_9, 14},
340  {"rmsnorm", NULL, ck_decoder_backward_bindings_10, 6},
341 };
342 
#define CK_DT_MASK(dt)
Definition: ckernel_dtype.h:53
@ CK_DT_FP32
Definition: ckernel_dtype.h:29
@ CK_DT_INT4
Definition: ckernel_dtype.h:35
@ CK_DT_BF16
Definition: ckernel_dtype.h:30
@ CK_DT_INT8
Definition: ckernel_dtype.h:34
static const CKPlanBinding ck_decoder_forward_bindings_7[]
static const CKPlanBinding ck_decoder_forward_bindings_9[]
static const CKPlanBinding ck_decoder_backward_bindings_9[]
static const CKPlanBinding ck_decoder_backward_bindings_2[]
const CKPlanStepV2 ck_decoder_forward_plan_v2[]
static const CKPlanBinding ck_decoder_forward_bindings_1[]
static const CKPlanBinding ck_decoder_backward_bindings_6[]
const CKPlanStep ck_decoder_forward_plan[]
static const CKPlanBinding ck_decoder_forward_bindings_6[]
static const CKPlanBinding ck_decoder_forward_bindings_0[]
const size_t ck_decoder_backward_plan_count
const size_t ck_decoder_forward_plan_v2_count
static const CKPlanBinding ck_decoder_backward_bindings_0[]
static const CKPlanBinding ck_decoder_backward_bindings_10[]
static const CKPlanBinding ck_decoder_forward_bindings_8[]
static const CKPlanBinding ck_decoder_forward_bindings_3[]
static const CKPlanBinding ck_decoder_forward_bindings_4[]
static const CKPlanBinding ck_decoder_backward_bindings_7[]
const size_t ck_decoder_backward_plan_v2_count
static const CKPlanBinding ck_decoder_forward_bindings_5[]
const size_t ck_decoder_forward_plan_count
static const CKPlanBinding ck_decoder_backward_bindings_5[]
const CKPlanStep ck_decoder_backward_plan[]
static const CKPlanBinding ck_decoder_forward_bindings_10[]
static const CKPlanBinding ck_decoder_forward_bindings_2[]
const CKKernelSpec ck_kernel_specs[]
static const CKPlanBinding ck_decoder_backward_bindings_8[]
static const CKPlanBinding ck_decoder_backward_bindings_1[]
static const CKPlanBinding ck_decoder_backward_bindings_3[]
const CKPlanStepV2 ck_decoder_backward_plan_v2[]
const CKBufferSpec ck_decoder_buffers[]
const size_t ck_kernel_spec_count
const size_t ck_decoder_buffer_count
static const CKPlanBinding ck_decoder_backward_bindings_4[]
@ CK_ROLE_WEIGHT
@ CK_ROLE_SCRATCH
@ CK_ROLE_GRAD
@ CK_ROLE_ACTIVATION
@ CK_ROLE_INPUT
@ CK_ROLE_OUTPUT
@ CK_DIM_ALIGNED_INTERMEDIATE
@ CK_DIM_NUM_HEADS
@ CK_DIM_ALIGNED_EMBED
@ CK_DIM_TOKENS
@ CK_DIM_ALIGNED_CTX
@ CK_DIM_END
@ CK_DIM_ALIGNED_HEAD
@ CK_DIM_HEAD_DIM
@ CK_DIM_NUM_KV_HEADS
@ CK_DIM_VOCAB
@ CK_SCOPE_LAYER
@ CK_SCOPE_GLOBAL