#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#include <immintrin.h>
    if (!q_token || !k_cache || !v_cache || !out_token) {
    if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
    if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
    static int use_strict = -1;
        const char *env = getenv("CK_FLASH_ATTN_STRICT");
        use_strict = (env && env[0] && env[0] != '0') ? 1 : 0;
    const float scale = 1.0f / sqrtf((float)head_dim);
    const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
#pragma omp parallel for schedule(static) if(num_heads > 1)
    for (int h = 0; h < num_heads; ++h) {
        const int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
        const float *k_head = k_cache + (size_t)kv_head * head_stride;
        const float *v_head = v_cache + (size_t)kv_head * head_stride;
        float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;
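The kv_head expression above is the usual grouped-query mapping: query head h reads KV head floor(h * num_kv_heads / num_heads), so consecutive query heads share one KV head. A minimal, self-contained sketch of that mapping with hypothetical head counts (8 query heads, 2 KV heads), just to make the grouping concrete:

#include <stdio.h>

/* Hypothetical head counts for illustration only. */
int main(void) {
    const int num_heads = 8, num_kv_heads = 2;
    for (int h = 0; h < num_heads; ++h) {
        int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
        printf("q head %d -> kv head %d\n", h, kv_head);  /* heads 0-3 -> 0, heads 4-7 -> 1 */
    }
    return 0;
}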
                                int aligned_embed_dim)
    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
    for (size_t i = 0; i < total; ++i) {
        out[i] = a[i] + b[i];
                                int aligned_embed_dim)
    if (!d_out || !d_a || !d_b) {
    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
    for (size_t i = 0; i < total; ++i) {
                                const float *wq, const float *bq,
                                const float *wk, const float *bk,
                                const float *wv, const float *bv,
                                float *q, float *k, float *v,
                                int kv_stride_tokens,
                                int aligned_embed_dim,
                                int aligned_head_dim)
    if (!input || !wq || !wk || !wv || !q || !k || !v) {
    if (kv_stride_tokens < tokens) {
    size_t head_weight_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
    size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
    for (int h = 0; h < num_heads; ++h) {
        const float *wq_h = wq + (size_t)h * head_weight_stride;
        const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q + (size_t)h * q_head_stride;
                  tokens, aligned_head_dim, aligned_embed_dim);
    for (int h = 0; h < num_kv_heads; ++h) {
        const float *wk_h = wk + (size_t)h * head_weight_stride;
        const float *wv_h = wv + (size_t)h * head_weight_stride;
        const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k + (size_t)h * kv_head_stride;
        float *v_h = v + (size_t)h * kv_head_stride;
                  tokens, aligned_head_dim, aligned_embed_dim);
                  tokens, aligned_head_dim, aligned_embed_dim);
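The three strides above fix the head-major layout used throughout this file: q is [num_heads][tokens][aligned_head_dim], while k and v are [num_kv_heads][kv_stride_tokens][aligned_head_dim], so each head's projection is one contiguous tokens x aligned_head_dim block. A minimal addressing sketch (the helper name is hypothetical, not part of the engine):

#include <stddef.h>

/* Hypothetical helper: flat offset of element (head, token, dim) in a head-major tensor. */
static size_t head_major_offset(int head, int token, int dim,
                                int stride_tokens, int aligned_head_dim) {
    return (size_t)head * (size_t)stride_tokens * (size_t)aligned_head_dim
         + (size_t)token * (size_t)aligned_head_dim
         + (size_t)dim;
}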
    static int cached = -2;
        const char *env = getenv("CK_LAYER_DEBUG");
        if (env && (env[0] == '1' || env[0] == 'y' || env[0] == 'Y')) {
    int nan_count = 0, inf_count = 0;
    float min_val = 1e38f, max_val = -1e38f;
    for (int i = 0; i < size; ++i) {
        } else if (isinf(v)) {
        if (v < min_val) min_val = v;
        if (v > max_val) max_val = v;
    if (nan_count > 0 || inf_count > 0) {
        fprintf(stderr, "[LAYER_DEBUG] %-30s size=%5d nan=%d inf=%d\n",
                stage, size, nan_count, inf_count);
        fprintf(stderr, "[LAYER_DEBUG] %-30s size=%5d range=[%.3e, %.3e]\n",
                stage, size, min_val, max_val);
    int nan_scale = 0, inf_scale = 0;
    float min_d = 1e38f, max_d = -1e38f;
    for (int i = 0; i < num_blocks; ++i) {
        float d = blocks[i].d;
        } else if (isinf(d)) {
        if (d < min_d) min_d = d;
        if (d > max_d) max_d = d;
    if (nan_scale > 0 || inf_scale > 0) {
        fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d nan_scale=%d inf_scale=%d\n",
                stage, num_blocks, nan_scale, inf_scale);
        fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d scale_range=[%.3e, %.3e]\n",
                stage, num_blocks, min_d, max_d);
    int nan_d = 0, nan_dmin = 0;
    float min_d = 1e38f, max_d = -1e38f;
    for (int i = 0; i < num_blocks; ++i) {
        if (isnan(d)) nan_d++;
        if (isnan(dm)) nan_dmin++;
        if (!isnan(d) && !isinf(d)) {
            if (d < min_d) min_d = d;
            if (d > max_d) max_d = d;
    if (nan_d > 0 || nan_dmin > 0) {
        fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d nan_d=%d nan_dmin=%d\n",
                stage, num_blocks, nan_d, nan_dmin);
        fprintf(stderr, "[LAYER_DEBUG] %-30s blocks=%d d_range=[%.3e, %.3e]\n",
                stage, num_blocks, min_d, max_d);
    static int cached = -2;
        const char *env = getenv("CK_Q8K_ACTIVATIONS");
        if (!env || !env[0]) {
        if (env[0] == '0' || env[0] == 'n' || env[0] == 'N' ||
            env[0] == 'f' || env[0] == 'F') {
                                           const void *wq, const float *bq,
                                           const void *wk, const float *bk,
                                           const void *wv, const float *bv,
                                           float *q, float *k, float *v,
                                           int kv_stride_tokens,
                                           int aligned_embed_dim,
                                           int aligned_head_dim)
    if (!input || !wq || !wk || !wv || !q || !k || !v) {
    if (kv_stride_tokens < tokens) {
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
    const uint8_t *wq_bytes = (const uint8_t *)wq;
    const uint8_t *wk_bytes = (const uint8_t *)wk;
    const uint8_t *wv_bytes = (const uint8_t *)wv;
    for (int h = 0; h < num_heads; ++h) {
        const void *wq_h = wq_bytes + (size_t)h * head_w_bytes;
        const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q + (size_t)h * q_head_stride;
                  tokens, aligned_head_dim, aligned_embed_dim);
    for (int h = 0; h < num_kv_heads; ++h) {
        const void *wk_h = wk_bytes + (size_t)h * head_w_bytes;
        const void *wv_h = wv_bytes + (size_t)h * head_w_bytes;
        const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k + (size_t)h * kv_head_stride;
        float *v_h = v + (size_t)h * kv_head_stride;
                  tokens, aligned_head_dim, aligned_embed_dim);
                  tokens, aligned_head_dim, aligned_embed_dim);
                                            const void *wq, const float *bq, CKDataType wq_dtype,
                                            const void *wk, const float *bk, CKDataType wk_dtype,
                                            const void *wv, const float *bv, CKDataType wv_dtype,
                                            float *q, float *k, float *v,
                                            int kv_stride_tokens,
                                            int aligned_embed_dim,
                                            int aligned_head_dim)
    if (!input || !wq || !wk || !wv || !q || !k || !v) {
    if (kv_stride_tokens < tokens) {
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
    const uint8_t *wq_bytes = (const uint8_t *)wq;
    const uint8_t *wk_bytes = (const uint8_t *)wk;
    const uint8_t *wv_bytes = (const uint8_t *)wv;
    for (int h = 0; h < num_heads; ++h) {
            ? (const void *)((const float *)wq + (size_t)h * head_w_elems)
            : (const void *)(wq_bytes + (size_t)h * wq_head_bytes);
        const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q + (size_t)h * q_head_stride;
                  tokens, aligned_head_dim, aligned_embed_dim, wq_dtype);
    for (int h = 0; h < num_kv_heads; ++h) {
            ? (const void *)((const float *)wk + (size_t)h * head_w_elems)
            : (const void *)(wk_bytes + (size_t)h * wk_head_bytes);
            ? (const void *)((const float *)wv + (size_t)h * head_w_elems)
            : (const void *)(wv_bytes + (size_t)h * wv_head_bytes);
        const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k + (size_t)h * kv_head_stride;
        float *v_h = v + (size_t)h * kv_head_stride;
                  tokens, aligned_head_dim, aligned_embed_dim, wk_dtype);
                  tokens, aligned_head_dim, aligned_embed_dim, wv_dtype);
                                             int aligned_embed_dim,
                                             int aligned_head_dim)
    if (!attn_out || !wo || !out || !scratch) {
    const int K = num_heads * aligned_head_dim;
    if (K != aligned_embed_dim) {
    const size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
    for (int t = 0; t < tokens; ++t) {
        float *dst = scratch + (size_t)t * (size_t)aligned_embed_dim;
        for (int h = 0; h < num_heads; ++h) {
            const float *src = attn_out + (size_t)h * head_in_stride + (size_t)t * (size_t)aligned_head_dim;
            memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
                   src,
                   (size_t)aligned_head_dim * sizeof(float));
              tokens, aligned_embed_dim, aligned_embed_dim);
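The loop above only de-interleaves the head-major attention output into a token-major [tokens x (num_heads * aligned_head_dim)] scratch buffer; the output projection is then a single GEMM against wo. A naive reference sketch of the same two steps, under the assumption that aligned_embed_dim equals num_heads * aligned_head_dim (as the check above enforces) and that wo is stored row-major as [N, K] like the other gemm_nt-style weights; the function name is hypothetical:

#include <string.h>

static void project_attention_ref(const float *attn_out, const float *wo, const float *bo,
                                  float *scratch, float *out,
                                  int tokens, int num_heads, int aligned_head_dim, int D) {
    /* 1) concat heads: scratch[t][h*ad .. h*ad+ad) = attn_out[h][t][0..ad) */
    for (int t = 0; t < tokens; ++t)
        for (int h = 0; h < num_heads; ++h)
            memcpy(scratch + (size_t)t * D + (size_t)h * aligned_head_dim,
                   attn_out + (size_t)h * tokens * aligned_head_dim + (size_t)t * aligned_head_dim,
                   (size_t)aligned_head_dim * sizeof(float));
    /* 2) out[t][n] = sum_k scratch[t][k] * wo[n][k] + bo[n] */
    for (int t = 0; t < tokens; ++t)
        for (int n = 0; n < D; ++n) {
            float acc = bo ? bo[n] : 0.0f;
            for (int k = 0; k < D; ++k)
                acc += scratch[(size_t)t * D + k] * wo[(size_t)n * D + k];
            out[(size_t)t * D + n] = acc;
        }
}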
                                              int aligned_embed_dim,
                                              int aligned_head_dim,
    if (!attn_out || !wo || !out || !scratch) {
    const int K = num_heads * aligned_head_dim;
    if (K != aligned_embed_dim) {
    const size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
    for (int t = 0; t < tokens; ++t) {
        float *dst = scratch + (size_t)t * (size_t)aligned_embed_dim;
        for (int h = 0; h < num_heads; ++h) {
            const float *src = attn_out + (size_t)h * head_in_stride + (size_t)t * (size_t)aligned_head_dim;
            memcpy(dst + (size_t)h * (size_t)aligned_head_dim,
                   src,
                   (size_t)aligned_head_dim * sizeof(float));
              tokens, aligned_embed_dim, aligned_embed_dim, wo_dtype);
                                int aligned_embed_dim,
                                int aligned_intermediate_dim)
    int up_dim = 2 * aligned_intermediate_dim;
              tokens, up_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
              tokens, aligned_embed_dim, aligned_intermediate_dim);
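The shape bookkeeping here is the usual SwiGLU split: the first GEMM writes 2 * aligned_intermediate_dim columns per token (gate and up halves concatenated), swiglu_forward collapses them to aligned_intermediate_dim activations, and the second GEMM projects back to aligned_embed_dim. A minimal sketch of the element-wise combine, assuming the gate half precedes the up half within each fc1 row (consistent with the w_gate / w_up split used later in this file); the function name is hypothetical:

#include <math.h>

/* Reference SwiGLU combine: out[i] = silu(gate[i]) * up[i],
 * with fc1_row laid out as [gate(0..dim-1) | up(0..dim-1)]. Sketch only. */
static void swiglu_combine_row(const float *fc1_row, float *out_row, int dim) {
    const float *gate = fc1_row;
    const float *up   = fc1_row + dim;
    for (int i = 0; i < dim; ++i) {
        float g = gate[i];
        float silu = g / (1.0f + expf(-g));  /* silu(x) = x * sigmoid(x) */
        out_row[i] = silu * up[i];
    }
}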
                                      int aligned_embed_dim,
                                      int aligned_intermediate_dim)
    int up_dim = 2 * aligned_intermediate_dim;
              tokens, up_dim, aligned_embed_dim, w1_dtype);
    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
              tokens, aligned_embed_dim, aligned_intermediate_dim, w2_dtype);
                                        int aligned_embed_dim,
                                        int aligned_intermediate_dim)
    if (!input || !w1 || !w2 || !fc1_out || !swiglu_out || !output) {
    if ((aligned_embed_dim % QK_K) != 0 || (aligned_intermediate_dim % QK_K) != 0) {
    const int up_dim = 2 * aligned_intermediate_dim;
    const int q8_blocks_embed = aligned_embed_dim / QK_K;
    const int q8_blocks_inter = aligned_intermediate_dim / QK_K;
    const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
              1, up_dim, aligned_embed_dim);
              1, aligned_embed_dim, aligned_intermediate_dim);
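The block counts above size a scratch buffer of Q8_K activation blocks: each row that feeds a Q4_K x Q8_K GEMM is first quantized in QK_K-sized blocks, and the scratch must hold the larger of the embed-row and intermediate-row block counts. A sketch of the call sequence this fragment implies for one row, using quantize_row_q8_k and gemm_nt_q4_k_q8_k as listed in the reference section below; the q8_scratch buffer and its allocation are assumptions, not shown in the original:

    /* Sketch only: q8_scratch holds q8_blocks_max block_q8_K blocks. */
    quantize_row_q8_k(input, q8_scratch, aligned_embed_dim);           /* aligned_embed_dim / QK_K blocks */
    gemm_nt_q4_k_q8_k(q8_scratch, w1, b1, fc1_out, 1, up_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, 1, aligned_intermediate_dim);
    quantize_row_q8_k(swiglu_out, q8_scratch, aligned_intermediate_dim);  /* reuse scratch */
    gemm_nt_q4_k_q8_k(q8_scratch, w2, b2, output, 1, aligned_embed_dim, aligned_intermediate_dim);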
                                    const float *wq, const float *bq,
                                    const float *wk, const float *bk,
                                    const float *wv, const float *bv,
                                    float *q, float *k, float *v,
                                    int kv_stride_tokens,
                                    int aligned_embed_dim,
                                    int aligned_head_dim)
    if (!input || !wq || !wk || !wv || !q || !k || !v) {
    if (kv_stride_tokens < tokens) {
    size_t head_weight_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
    size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
    for (int h = 0; h < num_heads; ++h) {
        const float *wq_h = wq + (size_t)h * head_weight_stride;
        const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q + (size_t)h * q_head_stride;
                  tokens, aligned_head_dim, aligned_embed_dim);
    for (int h = 0; h < num_kv_heads; ++h) {
        const float *wk_h = wk + (size_t)h * head_weight_stride;
        const float *wv_h = wv + (size_t)h * head_weight_stride;
        const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k + (size_t)h * kv_head_stride;
        float *v_h = v + (size_t)h * kv_head_stride;
                  tokens, aligned_head_dim, aligned_embed_dim);
                  tokens, aligned_head_dim, aligned_embed_dim);
                            int aligned_embed_dim)
    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
    for (size_t i = 0; i < total; ++i) {
                                      int aligned_embed_dim,
                                      int aligned_head_dim)
    if (!attn_out || !wo || !out) {
    if (num_heads > 1 && !scratch) {
    size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
    size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
    for (int h = 0; h < num_heads; ++h) {
        const float *head_in = attn_out + (size_t)h * head_in_stride;
        const float *wo_h = wo + (size_t)h * head_weight_stride;
                  tokens, aligned_embed_dim, aligned_head_dim);
              tokens, aligned_embed_dim, aligned_head_dim);
                                          int aligned_embed_dim,
                                          int aligned_head_dim)
    if (!attn_out || !wo || !out) {
    if (num_heads > 1 && !scratch) {
    size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
    size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
    for (int h = 0; h < num_heads; ++h) {
        const float *head_in = attn_out + (size_t)h * head_in_stride;
        const float *wo_h = wo + (size_t)h * head_weight_stride;
                  tokens, aligned_embed_dim, aligned_head_dim);
              tokens, aligned_embed_dim, aligned_head_dim);
                                           const float *attn_out,
                                           int aligned_embed_dim,
                                           int aligned_head_dim)
    if (!d_out || !attn_out || !wo || !d_attn_out || !d_wo || !d_bo) {
    for (int d = 0; d < aligned_embed_dim; ++d) {
    for (int t = 0; t < tokens; ++t) {
        const float *row = d_out + (size_t)t * (size_t)aligned_embed_dim;
        for (int d = 0; d < aligned_embed_dim; ++d) {
    size_t head_in_stride = (size_t)tokens * (size_t)aligned_head_dim;
    size_t head_weight_stride = (size_t)aligned_embed_dim * (size_t)aligned_head_dim;
    float *tmp_b = (float *)calloc((size_t)aligned_embed_dim, sizeof(float));
    for (int h = 0; h < num_heads; ++h) {
        const float *head_in = attn_out + (size_t)h * head_in_stride;
        const float *wo_h = wo + (size_t)h * head_weight_stride;
        float *d_head_in = d_attn_out + (size_t)h * head_in_stride;
        float *d_wo_h = d_wo + (size_t)h * head_weight_stride;
        memset(tmp_b, 0, (size_t)aligned_embed_dim * sizeof(float));
                                         int aligned_embed_dim,
                                         int aligned_head_dim,
    if (!d_q || !d_k || !d_v || !input || !wq || !wk || !wv ||
        !d_input || !d_wq || !d_bq || !d_wk || !d_bk || !d_wv || !d_bv || !scratch) {
    size_t total_in = (size_t)tokens * (size_t)aligned_embed_dim;
    for (size_t i = 0; i < total_in; ++i) {
    size_t head_weight_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    size_t head_out_stride = (size_t)tokens * (size_t)aligned_head_dim;
    for (int h = 0; h < num_heads; ++h) {
        const float *d_q_h = d_q + (size_t)h * head_out_stride;
        const float *wq_h = wq + (size_t)h * head_weight_stride;
        float *d_wq_h = d_wq + (size_t)h * head_weight_stride;
        float *d_bq_h = d_bq + (size_t)h * (size_t)aligned_head_dim;
    for (int h = 0; h < num_kv_heads; ++h) {
        const float *d_k_h = d_k + (size_t)h * head_out_stride;
        const float *d_v_h = d_v + (size_t)h * head_out_stride;
        const float *wk_h = wk + (size_t)h * head_weight_stride;
        const float *wv_h = wv + (size_t)h * head_weight_stride;
        float *d_wk_h = d_wk + (size_t)h * head_weight_stride;
        float *d_wv_h = d_wv + (size_t)h * head_weight_stride;
        float *d_bk_h = d_bk + (size_t)h * (size_t)aligned_head_dim;
        float *d_bv_h = d_bv + (size_t)h * (size_t)aligned_head_dim;
                                int aligned_embed_dim,
                                int aligned_intermediate_dim)
    int up_dim = 2 * aligned_intermediate_dim;
              tokens, up_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
              tokens, aligned_embed_dim, aligned_intermediate_dim);
                                    int aligned_embed_dim,
                                    int aligned_intermediate_dim)
    int up_dim = 2 * aligned_intermediate_dim;
              tokens, up_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
              tokens, aligned_embed_dim, aligned_intermediate_dim);
                                          int aligned_embed_dim,
                                          int aligned_intermediate_dim)
    if (!input_row || !w1 || !w2 || !swiglu_row || !output_row) {
    const float *w_gate = w1;
    const float *w_up = w1 + (size_t)aligned_intermediate_dim * (size_t)aligned_embed_dim;
    const float *b_gate = b1;
    const float *b_up = b1 ? (b1 + aligned_intermediate_dim) : NULL;
              aligned_intermediate_dim,
              aligned_intermediate_dim);
                                                int aligned_embed_dim,
                                                int aligned_intermediate_dim)
    if (!input_row || !w1 || !w2 || !output_row) {
    const float *w_gate = w1;
    const float *w_up = w1 + (size_t)aligned_intermediate_dim * (size_t)aligned_embed_dim;
    const float *b_gate = b1;
    const float *b_up = b1 ? (b1 + aligned_intermediate_dim) : NULL;
    const float *w_down = w2;
    const float *b_down = b2;
              aligned_intermediate_dim);
        !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
    if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
    const size_t token_slot = 0;
    const float *input_row = p->input + token_slot * (size_t)aligned_D;
    float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
    float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
    float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
    float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
    float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
    float *out_row = p->output + token_slot * (size_t)aligned_D;
    float ln1_rstd_tmp = 0.0f;
    float ln2_rstd_tmp = 0.0f;
    float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
    float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
    size_t q_elems = (size_t)H * (size_t)ad;
    size_t kv_elems = (size_t)H_kv * (size_t)ad;
    float q_token[q_elems];
    float k_token[kv_elems];
    float v_token[kv_elems];
    float attn_token[q_elems];
                                     q_token, k_token, v_token,
    int up_dim = 2 * aligned_intermediate;
    float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
    float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
              aligned_intermediate);
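Reading the visible fragments, the decode path appears to process exactly one token per call: it normalizes the input row, projects per-token Q/K/V into the stack buffers above, attends over the kv cache entries up to token_index within cache_capacity, projects the attention output, and then runs the per-row SwiGLU MLP, with residual additions around both halves. The per-step kernels are the ones listed in the reference section below (rmsnorm_forward, ck_qkv_project_head_major_token, kv_cache_write_head_major, ck_attention_flash_decode_wrapper, ck_attention_project_head_major_decode_token, ck_mlp_swiglu_forward_fused_token); the exact ordering of the elided calls is not shown here.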
        !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
    if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
    const size_t token_slot = 0;
    const float *input_row = p->input + token_slot * (size_t)aligned_D;
    float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
    float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
    float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
    float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
    float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
    float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
    float *out_row = p->output + token_slot * (size_t)aligned_D;
    float ln1_rstd_tmp = 0.0f;
    float ln2_rstd_tmp = 0.0f;
    float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
    float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
    size_t q_elems = (size_t)H * (size_t)ad;
    size_t kv_elems = (size_t)H_kv * (size_t)ad;
    float q_token[q_elems];
    float k_token[kv_elems];
    float v_token[kv_elems];
    float attn_token[q_elems];
                                     q_token, k_token, v_token,
              aligned_intermediate);
                                                  const void *wq, const float *bq,
                                                  const void *wk, const float *bk,
                                                  const void *wv, const float *bv,
                                                  int aligned_embed_dim,
                                                  int aligned_head_dim)
    if (!input_row || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const uint8_t *wq_bytes = (const uint8_t *)wq;
    const uint8_t *wk_bytes = (const uint8_t *)wk;
    const uint8_t *wv_bytes = (const uint8_t *)wv;
    for (int h = 0; h < num_heads; ++h) {
        const void *wq_h = wq_bytes + (size_t)h * head_w_bytes;
        const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q_token + (size_t)h * (size_t)aligned_head_dim;
                  1, aligned_head_dim, aligned_embed_dim);
    for (int h = 0; h < num_kv_heads; ++h) {
        const void *wk_h = wk_bytes + (size_t)h * head_w_bytes;
        const void *wv_h = wv_bytes + (size_t)h * head_w_bytes;
        const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k_token + (size_t)h * (size_t)aligned_head_dim;
        float *v_h = v_token + (size_t)h * (size_t)aligned_head_dim;
                  1, aligned_head_dim, aligned_embed_dim);
                  1, aligned_head_dim, aligned_embed_dim);
                                                       const void *wq, const float *bq,
                                                       const void *wk, const float *bk,
                                                       const void *wv, const float *bv,
                                                       int aligned_embed_dim,
                                                       int aligned_head_dim)
    if (!input_q8 || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const uint8_t *wq_bytes = (const uint8_t *)wq;
    const uint8_t *wk_bytes = (const uint8_t *)wk;
    const uint8_t *wv_bytes = (const uint8_t *)wv;
    for (int h = 0; h < num_heads; ++h) {
        const void *wq_h = wq_bytes + (size_t)h * head_w_bytes;
        const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q_token + (size_t)h * (size_t)aligned_head_dim;
                  1, aligned_head_dim, aligned_embed_dim);
    for (int h = 0; h < num_kv_heads; ++h) {
        const void *wk_h = wk_bytes + (size_t)h * head_w_bytes;
        const void *wv_h = wv_bytes + (size_t)h * head_w_bytes;
        const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k_token + (size_t)h * (size_t)aligned_head_dim;
        float *v_h = v_token + (size_t)h * (size_t)aligned_head_dim;
                  1, aligned_head_dim, aligned_embed_dim);
                  1, aligned_head_dim, aligned_embed_dim);
                                                const void *wq, const float *bq,
                                                const void *wk, const float *bk,
                                                const void *wv, const float *bv,
                                                float *q, float *k, float *v,
                                                int kv_stride_tokens,
                                                int aligned_embed_dim,
                                                int aligned_head_dim)
    if (!input || !wq || !wk || !wv || !q || !k || !v) {
    if (tokens <= 0 || aligned_embed_dim <= 0) {
    if (kv_stride_tokens < tokens) {
    if ((aligned_embed_dim % QK_K) != 0) {
    const int q8_blocks = aligned_embed_dim / QK_K;
    const size_t q_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
    float q_token[num_heads * aligned_head_dim];
    float k_token[num_kv_heads * aligned_head_dim];
    float v_token[num_kv_heads * aligned_head_dim];
    for (int t = 0; t < tokens; ++t) {
        const float *input_row = input + (size_t)t * (size_t)aligned_embed_dim;
        for (int h = 0; h < num_heads; ++h) {
            float *q_dst = q + (size_t)h * q_head_stride + (size_t)t * (size_t)aligned_head_dim;
            memcpy(q_dst,
                   q_token + (size_t)h * (size_t)aligned_head_dim,
                   (size_t)aligned_head_dim * sizeof(float));
        for (int h = 0; h < num_kv_heads; ++h) {
            float *k_dst = k + (size_t)h * kv_head_stride + (size_t)t * (size_t)aligned_head_dim;
            float *v_dst = v + (size_t)h * kv_head_stride + (size_t)t * (size_t)aligned_head_dim;
            memcpy(k_dst,
                   k_token + (size_t)h * (size_t)aligned_head_dim,
                   (size_t)aligned_head_dim * sizeof(float));
            memcpy(v_dst,
                   v_token + (size_t)h * (size_t)aligned_head_dim,
                   (size_t)aligned_head_dim * sizeof(float));
                                                     int aligned_embed_dim,
                                                     int aligned_head_dim)
    if (!attn_out || !wo || !out) {
    if (tokens <= 0 || aligned_embed_dim <= 0) {
    if ((aligned_embed_dim % QK_K) != 0) {
    const int K = num_heads * aligned_head_dim;
    if (K != aligned_embed_dim) {
    const int q8_blocks = aligned_embed_dim / QK_K;
    float attn_token[aligned_embed_dim];
    const size_t head_stride = (size_t)tokens * (size_t)aligned_head_dim;
    for (int t = 0; t < tokens; ++t) {
        for (int h = 0; h < num_heads; ++h) {
            const float *src = attn_out + (size_t)h * head_stride + (size_t)t * (size_t)aligned_head_dim;
            memcpy(attn_token + (size_t)h * (size_t)aligned_head_dim,
                   src,
                   (size_t)aligned_head_dim * sizeof(float));
                  out + (size_t)t * (size_t)aligned_embed_dim,
                  1, aligned_embed_dim, aligned_embed_dim);
                                                int aligned_embed_dim,
                                                int aligned_intermediate_dim)
    if (!input || !w1 || !w2 || !fc1_out || !swiglu_out || !output) {
    if ((aligned_embed_dim % QK_K) != 0 || (aligned_intermediate_dim % QK_K) != 0) {
    const int up_dim = 2 * aligned_intermediate_dim;
    const int q8_blocks_embed = aligned_embed_dim / QK_K;
    const int q8_blocks_inter = aligned_intermediate_dim / QK_K;
    const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
    for (int t = 0; t < tokens; ++t) {
        const float *input_row = input + (size_t)t * (size_t)aligned_embed_dim;
        float *fc1_row = fc1_out + (size_t)t * (size_t)up_dim;
                  1, up_dim, aligned_embed_dim);
    swiglu_forward(fc1_out, swiglu_out, tokens, aligned_intermediate_dim);
    for (int t = 0; t < tokens; ++t) {
        const float *swiglu_row = swiglu_out + (size_t)t * (size_t)aligned_intermediate_dim;
        float *out_row = output + (size_t)t * (size_t)aligned_embed_dim;
                  1, aligned_embed_dim, aligned_intermediate_dim);
                                                  const void *wq, const float *bq, CKDataType wq_dtype,
                                                  const void *wk, const float *bk, CKDataType wk_dtype,
                                                  const void *wv, const float *bv, CKDataType wv_dtype,
                                                  int aligned_embed_dim,
                                                  int aligned_head_dim)
    if (!input_row || !wq || !wk || !wv || !q_token || !k_token || !v_token) {
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const uint8_t *wq_bytes = (const uint8_t *)wq;
    const uint8_t *wk_bytes = (const uint8_t *)wk;
    const uint8_t *wv_bytes = (const uint8_t *)wv;
    for (int h = 0; h < num_heads; ++h) {
            ? (const void *)((const float *)wq + (size_t)h * head_w_elems)
            : (const void *)(wq_bytes + (size_t)h * wq_head_bytes);
        const float *bq_h = bq ? (bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *q_h = q_token + (size_t)h * (size_t)aligned_head_dim;
                  1, aligned_head_dim, aligned_embed_dim, wq_dtype);
    for (int h = 0; h < num_kv_heads; ++h) {
            ? (const void *)((const float *)wk + (size_t)h * head_w_elems)
            : (const void *)(wk_bytes + (size_t)h * wk_head_bytes);
            ? (const void *)((const float *)wv + (size_t)h * head_w_elems)
            : (const void *)(wv_bytes + (size_t)h * wv_head_bytes);
        const float *bk_h = bk ? (bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
        const float *bv_h = bv ? (bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
        float *k_h = k_token + (size_t)h * (size_t)aligned_head_dim;
        float *v_h = v_token + (size_t)h * (size_t)aligned_head_dim;
                  1, aligned_head_dim, aligned_embed_dim, wk_dtype);
                  1, aligned_head_dim, aligned_embed_dim, wv_dtype);
    if ((aligned_D % QK_K) == 0 && (aligned_intermediate % QK_K) == 0) {
              aligned_intermediate);
        !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
    if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
    const int K_concat = H * ad;
    const size_t token_slot = 0;
    const float *input_row = p->input + token_slot * (size_t)aligned_D;
    float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
    float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
    float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
    float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
    float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
    float *out_row = p->output + token_slot * (size_t)aligned_D;
    float ln1_rstd_tmp = 0.0f;
    float ln2_rstd_tmp = 0.0f;
    float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
    float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
    size_t q_elems = (size_t)H * (size_t)ad;
    size_t kv_elems = (size_t)H_kv * (size_t)ad;
    float q_token[q_elems];
    float k_token[kv_elems];
    float v_token[kv_elems];
    float attn_token[q_elems];
    if ((aligned_D % QK_K) == 0 && (aligned_intermediate % QK_K) == 0) {
        const int q8_blocks_embed = aligned_D / QK_K;
        const int q8_blocks_inter = aligned_intermediate / QK_K;
        const int q8_blocks_max = (q8_blocks_embed > q8_blocks_inter) ? q8_blocks_embed : q8_blocks_inter;
                                         q_token, k_token, v_token,
        for (int j = D; j < aligned_D; ++j) {
        int up_dim = 2 * aligned_intermediate;
        float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
        float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
                  aligned_intermediate);
                                         q_token, k_token, v_token,
        for (int j = D; j < aligned_D; ++j) {
        int up_dim = 2 * aligned_intermediate;
        float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
        float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
                  aligned_intermediate);
        !p->wq || !p->wk || !p->wv || !p->wo || !p->w1 || !p->w2 ||
    if (token_index < 0 || cache_capacity <= 0 || token_index >= cache_capacity) {
    const int K_concat = H * ad;
    const size_t token_slot = 0;
    const float *input_row = p->input + token_slot * (size_t)aligned_D;
    float *ln1_row = p->ln1_out + token_slot * (size_t)aligned_D;
    float *ln2_row = p->ln2_out + token_slot * (size_t)aligned_D;
    float *proj_row = p->proj_tmp + token_slot * (size_t)aligned_D;
    float *residual_row = p->residual1 + token_slot * (size_t)aligned_D;
    float *mlp_row = p->mlp_out + token_slot * (size_t)aligned_D;
    float *out_row = p->output + token_slot * (size_t)aligned_D;
    float ln1_rstd_tmp = 0.0f;
    float ln2_rstd_tmp = 0.0f;
    float *ln1_rstd = p->ln1_rstd ? (p->ln1_rstd + token_slot) : &ln1_rstd_tmp;
    float *ln2_rstd = p->ln2_rstd ? (p->ln2_rstd + token_slot) : &ln2_rstd_tmp;
    size_t q_elems = (size_t)H * (size_t)ad;
    size_t kv_elems = (size_t)H_kv * (size_t)ad;
    float q_token[q_elems];
    float k_token[kv_elems];
    float v_token[kv_elems];
    float attn_token[q_elems];
                                     q_token, k_token, v_token,
                                     (const float *)p->wo,
    for (int j = D; j < aligned_D; ++j) {
    int up_dim = 2 * aligned_intermediate;
    float *fc1_row = p->fc1_out + token_slot * (size_t)up_dim;
    float *swiglu_row = p->swiglu_out + token_slot * (size_t)aligned_intermediate;
              aligned_intermediate);
    int up_dim = 2 * aligned_intermediate;
    int num_threads = 1;
              aligned_intermediate,
CKDataType
Supported data types in C-Kernel-Engine.
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void gemm_nt_q4_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void gemm_naive_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void swiglu_forward(const float *input, float *output, int tokens, int dim)
void gemm_swiglu_fused(const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)
void swiglu_backward(const float *input, const float *d_output, float *d_input, int tokens, int dim)
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void attention_forward_causal_head_major_gqa(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
void rope_backward_qk(const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void gemm_nt_q4_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
void gemm_nt_q5_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void fc1_backward_kernel(const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
void gemm_nt_q6_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q8_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void fc2_backward_kernel(const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
void quantize_row_q8_k(const float *x, void *y, int k)
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
void attention_forward_causal_head_major_gqa_flash(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
void fused_mlp_swiglu_decode_v2(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
int ck_strict_parity_enabled(void)
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
static void ck_add_inplace(float *dst, const float *src, int tokens, int aligned_embed_dim)
static void ck_qkv_project_head_major_quant(const float *input, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_attention_project_head_major_quant(const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, CKDataType wo_dtype)
void ck_layer_backward_rmsnorm_swiglu(const CKLayerBackwardParams *p)
void ck_mlp_swiglu_forward_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *swiglu_row, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_attention_project_head_major(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_mlp_swiglu_forward(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_attention_project_head_major_ref(const float *attn_out, const float *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_residual_add_backward(const float *d_out, float *d_a, float *d_b, int tokens, int aligned_embed_dim)
static void ck_mlp_swiglu_forward_quant(const float *input, const void *w1, const float *b1, CKDataType w1_dtype, const void *w2, const float *b2, CKDataType w2_dtype, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_qkv_project_head_major_token_q4_k_q8_k(const block_q8_K *input_q8, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_attention_project_head_major_q4_k(const float *attn_out, const void *wo, const float *bo, float *out, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_q4_k_q8_k(const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_ref(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static int ck_q8k_activations_enabled(void)
void ck_layer_forward_rmsnorm_swiglu_quant(const CKLayerForwardParamsQ4K *p)
static int ck_layer_debug_enabled(void)
void ck_mlp_swiglu_forward_fully_fused_token(const float *input_row, const float *w1, const float *b1, const float *w2, const float *b2, float *output_row, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_mlp_swiglu_forward_q4_k_q8_k(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_layer_forward_rmsnorm_swiglu_decode(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_q4_k(const CKLayerForwardParamsQ4K *p)
void ck_qkv_project_head_major(const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_debug_check_q8k(const char *stage, const void *q8_buf, int num_blocks)
static void ck_attention_project_head_major_q4_k_q8_k(const float *attn_out, const void *wo, const float *bo, float *out, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void ck_mlp_swiglu_forward_q4_k(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
void ck_attention_flash_decode_wrapper(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
Wrapper to call TRUE flash attention from orchestration layer.
void ck_qkv_project_head_major_backward(const float *d_q, const float *d_k, const float *d_v, const float *input, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *d_input, float *d_wq, float *d_bq, float *d_wk, float *d_bk, float *d_wv, float *d_bv, float *scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
void ck_layer_forward_rmsnorm_swiglu(const CKLayerForwardParams *p)
void ck_layer_forward_rmsnorm_swiglu_decode_fused(const CKLayerForwardParams *p, int token_index, int cache_capacity)
void ck_layer_forward_rmsnorm_swiglu_decode_q4_k(const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
static void ck_debug_check_q4k_weights(const char *stage, const void *q4_buf, int num_blocks)
static void ck_debug_check_buffer(const char *stage, const float *buf, int size)
static void ck_mlp_swiglu_forward_ref(const float *input, const float *w1, const float *b1, const float *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
void ck_layer_forward_rmsnorm_swiglu_ref(const CKLayerForwardParams *p)
static void ck_qkv_project_head_major_token_q4_k(const float *input_row, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_qkv_project_head_major_q4_k(const float *input, const void *wq, const float *bq, const void *wk, const float *bk, const void *wv, const float *bv, float *q, float *k, float *v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
static void ck_mlp_swiglu_forward_q4_k_q8_k_prefill(const float *input, const void *w1, const float *b1, const void *w2, const float *b2, float *fc1_out, float *swiglu_out, float *output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
static void ck_qkv_project_head_major_token_quant(const float *input_row, const void *wq, const float *bq, CKDataType wq_dtype, const void *wk, const float *bk, CKDataType wk_dtype, const void *wv, const float *bv, CKDataType wv_dtype, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_attention_project_head_major_backward(const float *d_out, const float *attn_out, const float *wo, float *d_attn_out, float *d_wo, float *d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
void ck_layer_forward_rmsnorm_swiglu_decode_quant(const CKLayerForwardParamsQ4K *p, int token_index, int cache_capacity)
void ck_qkv_project_head_major_token(const float *input_row, const float *wq, const float *bq, const float *wk, const float *bk, const float *wv, const float *bv, float *q_token, float *k_token, float *v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
void ck_attention_project_head_major_decode_token(const float *attn_token, const float *wo, const float *bo, float *out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)