#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#include <immintrin.h>
#endif

#include <math.h>   /* sqrtf, expf */
#include <stddef.h> /* size_t, NULL */
#include <stdint.h> /* uint8_t */
#include <stdlib.h> /* malloc, free, aligned_alloc (used by the example sketches) */

#if defined(CK_USE_BLAS) /* guard macro name assumed; the original include and guard were elided */
#include <cblas.h>
#endif
#define PREFILL_TILE_M 64
#define PREFILL_TILE_N 256
/* Round value up to a multiple of align (align must be a power of two). */
static size_t align_up_size(size_t value, size_t align)
{
    return (value + align - 1) & ~(align - 1);
}
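/* Worked example (illustrative; example_align_up_check is not part of the
   source): with a power-of-two align, adding align - 1 and masking off the
   low bits rounds up to the next boundary. */
static int example_align_up_check(void)
{
    return align_up_size(100, 64) == 128  /* 100 rounds up to the next 64B boundary */
        && align_up_size(128, 64) == 128  /* already-aligned values are unchanged */
        && align_up_size(0, 64) == 0;     /* zero stays zero */
}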
/* Horizontal sum of an 8-lane vector (AVX fallback when AVX-512 is absent). */
#if defined(__AVX__) && !defined(__AVX512F__)
static inline float hsum256_prefill(__m256 v) {
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 hi = _mm256_extractf128_ps(v, 1);
    __m128 sum128 = _mm_add_ps(lo, hi);
    sum128 = _mm_hadd_ps(sum128, sum128);
    sum128 = _mm_hadd_ps(sum128, sum128);
    return _mm_cvtss_f32(sum128);
}
#endif
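#if defined(__AVX__) && !defined(__AVX512F__)
/* Equivalence sketch (illustrative; not part of the source): the horizontal
   sum must agree with a scalar loop over the same eight lanes. */
static int example_hsum256_check(void)
{
    const float vals[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
    float expect = 0.0f;
    for (int i = 0; i < 8; ++i) {
        expect += vals[i]; /* 36 */
    }
    return hsum256_prefill(_mm256_loadu_ps(vals)) == expect;
}
#endif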
/* Compute RMSNorm for a tile of tokens. Rows are aligned_embed_dim apart;
   only the first embed_dim columns carry data, the rest is padding. */
static void rmsnorm_tile(const float *input, const float *gamma, float *output,
                         int tile_m, int embed_dim, int aligned_embed_dim,
                         float eps)
{
    for (int t = 0; t < tile_m; ++t) {
        const float *x = input + (size_t)t * (size_t)aligned_embed_dim;
        float *y = output + (size_t)t * (size_t)aligned_embed_dim;

#if defined(__AVX512F__)
        __m512 sum_sq_vec = _mm512_setzero_ps();
        int d = 0;
        for (; d + 16 <= embed_dim; d += 16) {
            __m512 xv = _mm512_loadu_ps(&x[d]);
            sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
        }
        float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
        for (; d < embed_dim; ++d) {
            sum_sq += x[d] * x[d];
        }

        float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        d = 0;
        for (; d + 16 <= embed_dim; d += 16) {
            __m512 xv = _mm512_loadu_ps(&x[d]);
            __m512 gv = gamma ? _mm512_loadu_ps(&gamma[d]) : _mm512_set1_ps(1.0f);
            __m512 yv = _mm512_mul_ps(_mm512_mul_ps(xv, rstd_vec), gv);
            _mm512_storeu_ps(&y[d], yv);
        }
        for (; d < embed_dim; ++d) {
            float g = gamma ? gamma[d] : 1.0f;
            y[d] = x[d] * rstd * g;
        }
#elif defined(__AVX__)
        __m256 sum_sq_vec = _mm256_setzero_ps();
        int d = 0;
        for (; d + 8 <= embed_dim; d += 8) {
            __m256 xv = _mm256_loadu_ps(&x[d]);
            sum_sq_vec = _mm256_add_ps(sum_sq_vec, _mm256_mul_ps(xv, xv));
        }
        float sum_sq = hsum256_prefill(sum_sq_vec);
        for (; d < embed_dim; ++d) {
            sum_sq += x[d] * x[d];
        }

        float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);
        __m256 rstd_vec = _mm256_set1_ps(rstd);
        d = 0;
        for (; d + 8 <= embed_dim; d += 8) {
            __m256 xv = _mm256_loadu_ps(&x[d]);
            __m256 gv = gamma ? _mm256_loadu_ps(&gamma[d]) : _mm256_set1_ps(1.0f);
            __m256 yv = _mm256_mul_ps(_mm256_mul_ps(xv, rstd_vec), gv);
            _mm256_storeu_ps(&y[d], yv);
        }
        for (; d < embed_dim; ++d) {
            float g = gamma ? gamma[d] : 1.0f;
            y[d] = x[d] * rstd * g;
        }
#else
        float sum_sq = 0.0f;
        for (int d = 0; d < embed_dim; ++d) {
            sum_sq += x[d] * x[d];
        }
        float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);
        for (int d = 0; d < embed_dim; ++d) {
            float g = gamma ? gamma[d] : 1.0f;
            y[d] = x[d] * rstd * g;
        }
#endif

        /* Zero the padded tail so downstream GEMMs over the aligned width
           read clean data. */
        for (int d = embed_dim; d < aligned_embed_dim; ++d) {
            y[d] = 0.0f;
        }
    }
}
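/* Usage sketch (illustrative; values are mine): normalize two tokens whose
   rows are padded out to an aligned stride. Passing gamma = NULL applies a
   unit scale, and columns embed_dim..aligned_embed_dim-1 of the output are
   zeroed padding. */
static void example_rmsnorm_tile(void)
{
    enum { EMBED = 6, ALIGNED = 8, TOKENS = 2 };
    float in[TOKENS * ALIGNED] = { 0 };
    float out[TOKENS * ALIGNED];
    for (int t = 0; t < TOKENS; ++t) {
        for (int d = 0; d < EMBED; ++d) {
            in[t * ALIGNED + d] = (float)(d + 1);
        }
    }
    rmsnorm_tile(in, NULL, out, TOKENS, EMBED, ALIGNED, 1e-6f);
    /* out[6], out[7], out[14], out[15] are now 0.0f (padding columns). */
}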
/* GEMM tile with N-dimension tiling (weight reuse): C = A * B_tile^T.
   A is tile_m x K, B_tile is tile_n x K (both row-major), and C rows are
   C_stride floats apart so a tile can land inside a larger matrix. */
static void gemm_tile_nt_strided(const float *A, const float *B_tile, float *C,
                                 int tile_m, int tile_n, int K, int C_stride)
{
#if defined(CK_USE_BLAS) /* guard macro name assumed; the original was elided */
    if (C_stride == tile_n) {
        /* Contiguous output: one sgemm call covers the whole tile. */
        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    tile_m, tile_n, K,
                    1.0f, A, K, B_tile, K,
                    0.0f, C, C_stride);
    } else {
        /* Strided output: one sgemv per row of A. */
        for (int i = 0; i < tile_m; ++i) {
            cblas_sgemv(CblasRowMajor, CblasNoTrans,
                        tile_n, K,
                        1.0f, B_tile, K, A + (size_t)i * K, 1,
                        0.0f, C + (size_t)i * C_stride, 1);
        }
    }
#else
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < tile_m; ++i) {
        const float *a_row = A + (size_t)i * K;
        float *c_row = C + (size_t)i * C_stride;

        for (int j = 0; j < tile_n; ++j) {
            const float *b_row = B_tile + (size_t)j * K;
            float sum = 0.0f;

#if defined(__AVX512F__)
            __m512 acc = _mm512_setzero_ps();
            int k = 0;
            for (; k + 16 <= K; k += 16) {
                __m512 av = _mm512_loadu_ps(a_row + k);
                __m512 bv = _mm512_loadu_ps(b_row + k);
                acc = _mm512_fmadd_ps(av, bv, acc);
            }
            sum = _mm512_reduce_add_ps(acc);
            for (; k < K; ++k) {
                sum += a_row[k] * b_row[k];
            }
#elif defined(__AVX__)
            __m256 acc = _mm256_setzero_ps();
            int k = 0;
            for (; k + 8 <= K; k += 8) {
                __m256 av = _mm256_loadu_ps(a_row + k);
                __m256 bv = _mm256_loadu_ps(b_row + k);
                acc = _mm256_add_ps(acc, _mm256_mul_ps(av, bv));
            }
            sum = hsum256_prefill(acc);
            for (; k < K; ++k) {
                sum += a_row[k] * b_row[k];
            }
#else
            for (int k = 0; k < K; ++k) {
                sum += a_row[k] * b_row[k];
            }
#endif
            c_row[j] = sum;
        }
    }
#endif /* CK_USE_BLAS */
}
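/* Worked example (illustrative; values are mine): B_tile rows are one-hot, so
   row i of C picks out the first two elements of row i of A, and the padding
   columns implied by C_stride are left untouched. */
static void example_gemm_tile_nt(void)
{
    enum { M = 2, N = 2, K = 3, STRIDE = 4 };
    float A[M * K] = { 1, 2, 3,   4, 5, 6 };
    float B[N * K] = { 1, 0, 0,   0, 1, 0 };  /* row j selects a_row[j] */
    float C[M * STRIDE] = { 0 };
    gemm_tile_nt_strided(A, B, C, M, N, K, STRIDE);
    /* C row 0 = {1, 2, _, _}, C row 1 = {4, 5, _, _} (underscores untouched). */
}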
/* Add a bias row to every row of a tile. */
static void add_bias_tile(float *out, const float *bias, int tile_m, int out_dim)
{
    for (int i = 0; i < tile_m; ++i) {
        float *row = out + (size_t)i * (size_t)out_dim;
        for (int j = 0; j < out_dim; ++j) {
            row[j] += bias[j];
        }
    }
}
/* Fused RMSNorm + single GEMM with 2D tiling (weight reuse). */
static void fused_rmsnorm_gemm_2d_tiled(const float *x, const float *gamma,
                                        const float *W, float *output,
                                        int seq_len, int hidden, int out_dim,
                                        float eps, float *x_norm_scratch)
{
    for (int n_start = 0; n_start < out_dim; n_start += PREFILL_TILE_N) {
        int tile_n = (n_start + PREFILL_TILE_N <= out_dim)
                         ? PREFILL_TILE_N
                         : (out_dim - n_start);
        const float *W_tile = W + (size_t)n_start * hidden;

        for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
            int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
                             ? PREFILL_TILE_M
                             : (seq_len - m_start);
            const float *x_tile = x + (size_t)m_start * hidden;
            float *out_tile = output + (size_t)m_start * out_dim + n_start;

            /* The scratch holds a single M-tile, so the norm is recomputed
               for each N-tile (the original has two call sites for this;
               their guard was lost in extraction). */
            rmsnorm_tile(x_tile, gamma, x_norm_scratch, tile_m, hidden, hidden, eps);

            gemm_tile_nt_strided(x_norm_scratch, W_tile, out_tile,
                                 tile_m, tile_n, hidden, out_dim);
        }
    }
}
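/* Usage sketch (illustrative): the scratch must hold one normalized M-tile,
   i.e. PREFILL_TILE_M * hidden floats, because the kernel renormalizes tile
   by tile. Assumes <stdlib.h>. */
static void example_fused_rmsnorm_gemm(const float *x, const float *gamma,
                                       const float *W, float *out,
                                       int seq_len, int hidden, int out_dim)
{
    float *scratch = (float *)malloc((size_t)PREFILL_TILE_M * (size_t)hidden * sizeof(float));
    if (!scratch) {
        return;
    }
    fused_rmsnorm_gemm_2d_tiled(x, gamma, W, out, seq_len, hidden, out_dim,
                                1e-6f, scratch);
    free(scratch);
}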
/* Fused RMSNorm + QKV projection for prefill (v3 optimized): normalize one
   M-tile into scratch, then project it to Q, K, and V. */
void fused_rmsnorm_qkv_prefill(const float *x, const float *gamma,
                               const float *Wq, const float *Wk, const float *Wv,
                               float *Q, float *K, float *V,
                               int seq_len, int hidden, int q_dim, int kv_dim,
                               float eps, float *scratch)
{
    for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
        int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
                         ? PREFILL_TILE_M : (seq_len - m_start);
        const float *x_tile = x + (size_t)m_start * hidden;
        rmsnorm_tile(x_tile, gamma, scratch, tile_m, hidden, hidden, eps);

        /* Projection call sites reconstructed: only the output-pointer lines
           survived extraction. */
        float *Q_tile = Q + (size_t)m_start * q_dim;
        gemm_tile_nt_strided(scratch, Wq, Q_tile, tile_m, q_dim, hidden, q_dim);

        float *K_tile = K + (size_t)m_start * kv_dim;
        gemm_tile_nt_strided(scratch, Wk, K_tile, tile_m, kv_dim, hidden, kv_dim);

        float *V_tile = V + (size_t)m_start * kv_dim;
        gemm_tile_nt_strided(scratch, Wv, V_tile, tile_m, kv_dim, hidden, kv_dim);
    }
}
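/* Usage sketch (illustrative): pair the kernel with its scratch-size helper
   (fused_rmsnorm_qkv_scratch_size, declared in the engine's header per the
   symbol index below). Assumes <stdlib.h>. */
static void example_fused_qkv(const float *x, const float *gamma,
                              const float *Wq, const float *Wk, const float *Wv,
                              float *Q, float *K, float *V,
                              int seq_len, int hidden, int q_dim, int kv_dim)
{
    float *scratch = (float *)malloc(fused_rmsnorm_qkv_scratch_size(hidden));
    if (!scratch) {
        return;
    }
    fused_rmsnorm_qkv_prefill(x, gamma, Wq, Wk, Wv, Q, K, V,
                              seq_len, hidden, q_dim, kv_dim, 1e-6f, scratch);
    free(scratch);
}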
/* Fused RMSNorm + QKV projection for prefill (head-major outputs). Each head's
   output is a contiguous [tokens x aligned_head_dim] block; K/V stride by
   kv_stride_tokens so rows can be written straight into a KV cache. */
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma,
                                          const float *Wq, const float *Bq,
                                          const float *Wk, const float *Bk,
                                          const float *Wv, const float *Bv,
                                          float *Q, float *K, float *V,
                                          int seq_len, int embed_dim, int aligned_embed_dim,
                                          int num_heads, int num_kv_heads,
                                          int head_dim, int aligned_head_dim,
                                          int kv_stride_tokens,
                                          float eps, float *scratch)
{
    if (!x || !gamma || !Wq || !Wk || !Wv || !Q || !K || !V || !scratch) {
        return;
    }
    if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
        head_dim <= 0 || aligned_head_dim <= 0 ||
        num_heads <= 0 || num_kv_heads <= 0) {
        return;
    }
    if (kv_stride_tokens < seq_len) {
        return;
    }

    const size_t q_head_stride = (size_t)seq_len * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
    const size_t head_w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;

    for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
        int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
                         ? PREFILL_TILE_M : (seq_len - m_start);
        const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
        rmsnorm_tile(x_tile, gamma, scratch, tile_m, embed_dim, aligned_embed_dim, eps);

        for (int h = 0; h < num_heads; ++h) {
            const float *wq_h = Wq + (size_t)h * head_w_stride;
            const float *bq_h = Bq ? (Bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
            float *q_h = Q + (size_t)h * q_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
            /* GEMM call name and bias handling reconstructed; only the
               argument tails survived extraction. */
            gemm_tile_nt_strided(scratch, wq_h, q_h,
                                 tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
            if (bq_h) {
                add_bias_tile(q_h, bq_h, tile_m, aligned_head_dim);
            }
        }

        for (int h = 0; h < num_kv_heads; ++h) {
            const float *wk_h = Wk + (size_t)h * head_w_stride;
            const float *wv_h = Wv + (size_t)h * head_w_stride;
            const float *bk_h = Bk ? (Bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
            const float *bv_h = Bv ? (Bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
            float *k_h = K + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
            float *v_h = V + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
            gemm_tile_nt_strided(scratch, wk_h, k_h,
                                 tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
            if (bk_h) {
                add_bias_tile(k_h, bk_h, tile_m, aligned_head_dim);
            }
            gemm_tile_nt_strided(scratch, wv_h, v_h,
                                 tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
            if (bv_h) {
                add_bias_tile(v_h, bv_h, tile_m, aligned_head_dim);
            }
        }
    }
}
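/* Indexing sketch for the head-major layout above (worked from the stride
   definitions; added for illustration):

     q(h, m, d) = Q[h * seq_len          * aligned_head_dim + m * aligned_head_dim + d]
     k(h, m, d) = K[h * kv_stride_tokens * aligned_head_dim + m * aligned_head_dim + d]
     v(h, m, d) = V[h * kv_stride_tokens * aligned_head_dim + m * aligned_head_dim + d]

   K and V stride by kv_stride_tokens rather than seq_len so each head's rows
   land directly in a preallocated KV-cache region. */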
/* Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations). */
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma,
                                                const void *Wq, const float *Bq, CKDataType wq_dt,
                                                const void *Wk, const float *Bk, CKDataType wk_dt,
                                                const void *Wv, const float *Bv, CKDataType wv_dt,
                                                float *Q, float *K, float *V,
                                                int seq_len, int embed_dim, int aligned_embed_dim,
                                                int num_heads, int num_kv_heads,
                                                int head_dim, int aligned_head_dim,
                                                int kv_stride_tokens,
                                                float eps, void *scratch)
{
    if (!x || !gamma || !Wq || !Wk || !Wv || !Q || !K || !V || !scratch) {
        return;
    }
    if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
        head_dim <= 0 || aligned_head_dim <= 0 ||
        num_heads <= 0 || num_kv_heads <= 0) {
        return;
    }
    if (aligned_embed_dim % 32 != 0) {
        return;
    }
    if (kv_stride_tokens < seq_len) {
        return;
    }

    /* Path selection (reconstructed from the dtype-support helpers). */
    const int use_q8_k_path = qkv_q8_k_dtype_supported(wq_dt) &&
                              qkv_q8_k_dtype_supported(wk_dt) &&
                              qkv_q8_k_dtype_supported(wv_dt);
    const int use_q8_0_path = qkv_q8_0_dtype_supported(wq_dt) &&
                              qkv_q8_0_dtype_supported(wk_dt) &&
                              qkv_q8_0_dtype_supported(wv_dt);
    if (!use_q8_k_path && !use_q8_0_path) {
        return;
    }

    /* Activation quantization format; CK_DTYPE_Q8_K / CK_DTYPE_Q8_0 are
       assumed enum names, the originals were elided in extraction. */
    const CKDataType act_quant_type = use_q8_k_path ? CK_DTYPE_Q8_K : CK_DTYPE_Q8_0;

    /* Scratch layout: FP32 normed tile first, then the quantized copy. */
    const size_t float_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_embed_dim * sizeof(float);
    const size_t q8_offset = align_up_size(float_bytes, 64); /* reconstructed; line elided */
    const size_t q8_row_bytes = ck_dtype_row_bytes(act_quant_type, (size_t)aligned_embed_dim);

    float *normed = (float *)scratch;
    uint8_t *q8_tile = (uint8_t *)scratch + q8_offset;

    const size_t q_head_stride = (size_t)seq_len * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t wq_head_bytes = ck_dtype_row_bytes(wq_dt, head_w_elems); /* reconstructed */
    const size_t wk_head_bytes = ck_dtype_row_bytes(wk_dt, head_w_elems);
    const size_t wv_head_bytes = ck_dtype_row_bytes(wv_dt, head_w_elems);

    for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
        int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
                         ? PREFILL_TILE_M : (seq_len - m_start);
        const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
        rmsnorm_tile(x_tile, gamma, normed, tile_m, embed_dim, aligned_embed_dim, eps);

        /* Quantize each normalized row into the activation format. */
        for (int t = 0; t < tile_m; ++t) {
            const float *row = normed + (size_t)t * (size_t)aligned_embed_dim;
            if (use_q8_k_path) {
                quantize_row_q8_k(row, q8_tile + (size_t)t * q8_row_bytes,
                                  aligned_embed_dim);
            } else {
                quantize_row_q8_0(row, q8_tile + (size_t)t * q8_row_bytes,
                                  aligned_embed_dim);
            }
        }

        for (int h = 0; h < num_heads; ++h) {
            const uint8_t *wq_h = (const uint8_t *)Wq + (size_t)h * wq_head_bytes;
            const float *bq_h = Bq ? (Bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
            float *q_h = Q + (size_t)h * q_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
            if (use_q8_k_path) {
                gemm_nt_q8_k_qkv_dispatch(q8_tile, wq_h, bq_h, q_h,
                                          tile_m, aligned_head_dim, aligned_embed_dim, wq_dt);
            } else {
                gemm_nt_q8_0_dispatch(q8_tile, wq_h, bq_h, q_h,
                                      tile_m, aligned_head_dim, aligned_embed_dim, wq_dt);
            }
        }

        for (int h = 0; h < num_kv_heads; ++h) {
            const uint8_t *wk_h = (const uint8_t *)Wk + (size_t)h * wk_head_bytes;
            const uint8_t *wv_h = (const uint8_t *)Wv + (size_t)h * wv_head_bytes;
            const float *bk_h = Bk ? (Bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
            const float *bv_h = Bv ? (Bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
            float *k_h = K + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
            float *v_h = V + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
            if (use_q8_k_path) {
                gemm_nt_q8_k_qkv_dispatch(q8_tile, wk_h, bk_h, k_h,
                                          tile_m, aligned_head_dim, aligned_embed_dim, wk_dt);
                gemm_nt_q8_k_qkv_dispatch(q8_tile, wv_h, bv_h, v_h,
                                          tile_m, aligned_head_dim, aligned_embed_dim, wv_dt);
            } else {
                gemm_nt_q8_0_dispatch(q8_tile, wk_h, bk_h, k_h,
                                      tile_m, aligned_head_dim, aligned_embed_dim, wk_dt);
                gemm_nt_q8_0_dispatch(q8_tile, wv_h, bv_h, v_h,
                                      tile_m, aligned_head_dim, aligned_embed_dim, wv_dt);
            }
        }
    }
}
/* Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant:
   one FP32 normed tile (64-byte aligned) plus one quantized tile sized for
   the larger of the two activation formats. The dtype constants and return
   expression are reconstructed; CK_DTYPE_* names are assumed. */
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
{
    if (aligned_embed_dim <= 0) {
        return 0;
    }
    const size_t float_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_embed_dim * sizeof(float);
    const size_t q8_k_row_bytes = ck_dtype_row_bytes(CK_DTYPE_Q8_K, (size_t)aligned_embed_dim);
    const size_t q8_0_row_bytes = ck_dtype_row_bytes(CK_DTYPE_Q8_0, (size_t)aligned_embed_dim);
    const size_t q8_row_bytes = (q8_k_row_bytes > q8_0_row_bytes) ? q8_k_row_bytes : q8_0_row_bytes;
    return align_up_size(float_bytes, 64) + (size_t)PREFILL_TILE_M * q8_row_bytes;
}
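/* Usage sketch (illustrative): query the size, then allocate a 64-byte
   aligned buffer to match the kernel's internal offsets (aligned_alloc
   needs the size rounded up to the alignment). Assumes <stdlib.h> with C11. */
static void *example_alloc_qkv_quant_scratch(int aligned_embed_dim)
{
    size_t bytes = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
    return bytes ? aligned_alloc(64, align_up_size(bytes, 64)) : NULL;
}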
/* Unfused version for comparison: normalize the whole sequence into x_norm,
   then run three independently tiled GEMMs for Q, K, and V. */
void unfused_rmsnorm_qkv_prefill(const float *x, const float *gamma,
                                 const float *Wq, const float *Wk, const float *Wv,
                                 float *x_norm, float *Q, float *K, float *V,
                                 int seq_len, int hidden, int q_dim, int kv_dim,
                                 float eps)
{
    rmsnorm_tile(x, gamma, x_norm, seq_len, hidden, hidden, eps);

    /* Q projection */
    for (int n_start = 0; n_start < q_dim; n_start += PREFILL_TILE_N) {
        int tile_n = (n_start + PREFILL_TILE_N <= q_dim)
                         ? PREFILL_TILE_N : (q_dim - n_start);
        const float *W_tile = Wq + (size_t)n_start * hidden;
        for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
            int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
                             ? PREFILL_TILE_M : (seq_len - m_start);
            const float *x_tile = x_norm + (size_t)m_start * hidden;
            float *out_tile = Q + (size_t)m_start * q_dim + n_start;
            gemm_tile_nt_strided(x_tile, W_tile, out_tile,
                                 tile_m, tile_n, hidden, q_dim);
        }
    }

    /* K projection */
    for (int n_start = 0; n_start < kv_dim; n_start += PREFILL_TILE_N) {
        int tile_n = (n_start + PREFILL_TILE_N <= kv_dim)
                         ? PREFILL_TILE_N : (kv_dim - n_start);
        const float *W_tile = Wk + (size_t)n_start * hidden;
        for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
            int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
                             ? PREFILL_TILE_M : (seq_len - m_start);
            const float *x_tile = x_norm + (size_t)m_start * hidden;
            float *out_tile = K + (size_t)m_start * kv_dim + n_start;
            gemm_tile_nt_strided(x_tile, W_tile, out_tile,
                                 tile_m, tile_n, hidden, kv_dim);
        }
    }

    /* V projection */
    for (int n_start = 0; n_start < kv_dim; n_start += PREFILL_TILE_N) {
        int tile_n = (n_start + PREFILL_TILE_N <= kv_dim)
                         ? PREFILL_TILE_N : (kv_dim - n_start);
        const float *W_tile = Wv + (size_t)n_start * hidden;
        for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
            int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
                             ? PREFILL_TILE_M : (seq_len - m_start);
            const float *x_tile = x_norm + (size_t)m_start * hidden;
            float *out_tile = V + (size_t)m_start * kv_dim + n_start;
            gemm_tile_nt_strided(x_tile, W_tile, out_tile,
                                 tile_m, tile_n, hidden, kv_dim);
        }
    }
}
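/* Comparison sketch (illustrative; buffer names are mine): the unfused path
   needs a full seq_len x hidden buffer for x_norm, where the fused kernels
   above get by with a single PREFILL_TILE_M-row scratch tile. Assumes
   <stdlib.h>. */
static void example_unfused_qkv(const float *x, const float *gamma,
                                const float *Wq, const float *Wk, const float *Wv,
                                float *Q, float *K, float *V,
                                int seq_len, int hidden, int q_dim, int kv_dim)
{
    float *x_norm = (float *)malloc((size_t)seq_len * (size_t)hidden * sizeof(float));
    if (!x_norm) {
        return;
    }
    unfused_rmsnorm_qkv_prefill(x, gamma, Wq, Wk, Wv, x_norm,
                                Q, K, V, seq_len, hidden, q_dim, kv_dim, 1e-6f);
    free(x_norm);
}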
/* Fused MLP (Gate + Up + SwiGLU + Down) for prefill, with optional biases.
   The intermediate dimension is tiled so the gate/up tiles stay in cache. */
void fused_mlp_swiglu_prefill_bias(const float *x, const float *W_gate,
                                   const float *W_up, const float *W_down,
                                   const float *B_gate, const float *B_up,
                                   const float *B_down, float *output,
                                   int seq_len, int hidden, int intermediate,
                                   float *scratch)
{
    const int TILE_N_INTER = 512;
    float *gate_tile = scratch;
    /* up_tile offset reconstructed: the original line was elided. */
    float *up_tile = scratch + (size_t)PREFILL_TILE_M * (size_t)TILE_N_INTER;
    /* hidden_tile aliases gate_tile: SwiGLU overwrites the gate in place. */
    float *hidden_tile = gate_tile;

    for (int inter_start = 0; inter_start < intermediate; inter_start += TILE_N_INTER) {
        int tile_inter = (inter_start + TILE_N_INTER <= intermediate)
                             ? TILE_N_INTER : (intermediate - inter_start);
        const float *W_gate_tile = W_gate + (size_t)inter_start * hidden;
        const float *W_up_tile = W_up + (size_t)inter_start * hidden;

        for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
            int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
                             ? PREFILL_TILE_M : (seq_len - m_start);
            const float *x_tile = x + (size_t)m_start * hidden;

            gemm_tile_nt_strided(x_tile, W_gate_tile, gate_tile,
                                 tile_m, tile_inter, hidden, tile_inter);
            gemm_tile_nt_strided(x_tile, W_up_tile, up_tile,
                                 tile_m, tile_inter, hidden, tile_inter);
            if (B_gate) {
                add_bias_tile(gate_tile, B_gate + inter_start, tile_m, tile_inter);
            }
            if (B_up) {
                add_bias_tile(up_tile, B_up + inter_start, tile_m, tile_inter);
            }

            /* SwiGLU in place: gate := silu(gate) * up. */
            for (int i = 0; i < tile_m; ++i) {
                float *g = gate_tile + (size_t)i * tile_inter;
                float *u = up_tile + (size_t)i * tile_inter;
                for (int j = 0; j < tile_inter; ++j) {
                    float gv = g[j];
                    float silu = gv / (1.0f + expf(-gv));
                    g[j] = silu * u[j];
                }
            }

            /* Down projection for this inter slice, accumulated into the
               output across inter tiles. */
            const float *W_down_slice = W_down + (size_t)inter_start;
            float *out_tile = output + (size_t)m_start * hidden;

            for (int i = 0; i < tile_m; ++i) {
                float *h = hidden_tile + (size_t)i * tile_inter;
                float *o = out_tile + (size_t)i * hidden;

                for (int d = 0; d < hidden; ++d) {
                    const float *w_row = W_down_slice + (size_t)d * intermediate;
                    /* First inter tile seeds with the bias (or zero); later
                       tiles accumulate onto the partial result. */
                    float sum = (inter_start == 0)
                                    ? (B_down ? B_down[d] : 0.0f)
                                    : o[d];
#if defined(__AVX512F__)
                    __m512 acc = _mm512_setzero_ps();
                    int j = 0;
                    for (; j + 16 <= tile_inter; j += 16) {
                        __m512 hv = _mm512_loadu_ps(h + j);
                        __m512 wv = _mm512_loadu_ps(w_row + j);
                        acc = _mm512_fmadd_ps(hv, wv, acc);
                    }
                    sum += _mm512_reduce_add_ps(acc);
                    for (; j < tile_inter; ++j) {
                        sum += h[j] * w_row[j];
                    }
#elif defined(__AVX__)
                    __m256 acc = _mm256_setzero_ps();
                    int j = 0;
                    for (; j + 8 <= tile_inter; j += 8) {
                        __m256 hv = _mm256_loadu_ps(h + j);
                        __m256 wv = _mm256_loadu_ps(w_row + j);
                        acc = _mm256_add_ps(acc, _mm256_mul_ps(hv, wv));
                    }
                    sum += hsum256_prefill(acc);
                    for (; j < tile_inter; ++j) {
                        sum += h[j] * w_row[j];
                    }
#else
                    for (int j = 0; j < tile_inter; ++j) {
                        sum += h[j] * w_row[j];
                    }
#endif
                    o[d] = sum;
                }
            }
        }
    }
}
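/* Reference sketch (illustrative; not part of the source): per element, the
   fused kernel above computes SwiGLU as silu(gate) * up with
   silu(v) = v / (1 + exp(-v)). A scalar reference for one activation pair: */
static float example_swiglu_scalar(float gate, float up)
{
    float silu = gate / (1.0f + expf(-gate));
    return silu * up;
}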
/* Fused MLP (Gate + Up + SwiGLU + Down) for prefill: bias-free wrapper. */
void fused_mlp_swiglu_prefill(const float *x, const float *W_gate,
                              const float *W_up, const float *W_down,
                              float *output, int seq_len, int hidden,
                              int intermediate, float *scratch)
{
    fused_mlp_swiglu_prefill_bias(x, W_gate, W_up, W_down,
                                  NULL, NULL, NULL,
                                  output, seq_len, hidden, intermediate,
                                  scratch);
}

/* Get scratch size for the fused MLP: one gate tile plus one up tile
   (return expression reconstructed; it was elided in extraction). */
size_t fused_mlp_swiglu_scratch_size(int intermediate)
{
    const int TILE_N_INTER = 512;
    (void)intermediate; /* the tiles are sized for the full TILE_N_INTER */
    return 2u * (size_t)PREFILL_TILE_M * (size_t)TILE_N_INTER * sizeof(float);
}

static float silu_prefill(float x)
{
    return x / (1.0f + expf(-x));
}
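/* Usage sketch (illustrative): size the scratch with the helper above, then
   run the bias-free MLP. Assumes <stdlib.h>. */
static void example_fused_mlp(const float *x, const float *W_gate, const float *W_up,
                              const float *W_down, float *out,
                              int seq_len, int hidden, int intermediate)
{
    float *scratch = (float *)malloc(fused_mlp_swiglu_scratch_size(intermediate));
    if (!scratch) {
        return;
    }
    fused_mlp_swiglu_prefill(x, W_gate, W_up, W_down, out,
                             seq_len, hidden, intermediate, scratch);
    free(scratch);
}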
/* Quantized fused MLP for prefill (W1 = gate+up stacked, W2 = down). */
void fused_mlp_swiglu_prefill_w1w2_quant(const float *x,
                                         const void *W1, const float *B1, CKDataType w1_dt,
                                         const void *W2, const float *B2, CKDataType w2_dt,
                                         float *output,
                                         int seq_len, int embed_dim, int aligned_embed_dim,
                                         int intermediate_dim, int aligned_intermediate_dim,
                                         void *scratch)
{
    if (!x || !W1 || !W2 || !output || !scratch) {
        return;
    }
    if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
        intermediate_dim <= 0 || aligned_intermediate_dim <= 0) {
        return;
    }
    if (aligned_embed_dim < embed_dim || aligned_intermediate_dim < intermediate_dim) {
        return;
    }
    if ((aligned_embed_dim % 32) != 0 || (aligned_intermediate_dim % 256) != 0) {
        return;
    }

    const int inter = aligned_intermediate_dim;
    const int tile_m_max = PREFILL_TILE_M; /* reconstructed; definition elided */

    /* Row byte counts. CK_DTYPE_Q8_K is an assumed enum name for the
       activation format; the original constants were elided in extraction. */
    const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DTYPE_Q8_K, (size_t)aligned_embed_dim);
    const size_t q8k_row_bytes = ck_dtype_row_bytes(CK_DTYPE_Q8_K, (size_t)aligned_intermediate_dim);
    const size_t w1_row_bytes = ck_dtype_row_bytes(w1_dt, (size_t)aligned_embed_dim);

    /* Scratch layout: q8 activations, gate tile, up tile, re-quantized q8k. */
    uint8_t *scratch_bytes = (uint8_t *)scratch;
    size_t q8_bytes = (size_t)tile_m_max * q8_row_bytes;
    size_t gate_bytes = (size_t)tile_m_max * (size_t)inter * sizeof(float);
    size_t up_bytes = gate_bytes;
    size_t gate_offset = align_up_size(q8_bytes, 64); /* reconstructed; line elided */
    size_t up_offset = gate_offset + align_up_size(gate_bytes, 64);
    size_t q8k_offset = up_offset + align_up_size(up_bytes, 64);

    uint8_t *q8_tile = scratch_bytes;
    float *gate_tile = (float *)(scratch_bytes + gate_offset);
    float *up_tile = (float *)(scratch_bytes + up_offset);
    uint8_t *q8k_tile = scratch_bytes + q8k_offset;

    /* W1 stacks the gate rows first, then the up rows; B1 likewise. */
    const uint8_t *w1_base = (const uint8_t *)W1;
    const uint8_t *w_gate = w1_base;
    const uint8_t *w_up = w1_base + (size_t)inter * w1_row_bytes;
    const float *b_gate = B1;
    const float *b_up = B1 ? (B1 + (size_t)inter) : NULL;

    for (int m_start = 0; m_start < seq_len; m_start += tile_m_max) {
        int tile_m = (m_start + tile_m_max <= seq_len)
                         ? tile_m_max : (seq_len - m_start);
        const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
        float *out_tile = output + (size_t)m_start * (size_t)aligned_embed_dim;

        /* Quantize the activation rows (the quantize and dispatch call names
           below are reconstructed from the symbol index; a Q8_0 variant path
           may exist in the original). */
        for (int t = 0; t < tile_m; ++t) {
            const float *row = x_tile + (size_t)t * (size_t)aligned_embed_dim;
            quantize_row_q8_k(row, q8_tile + (size_t)t * q8_row_bytes,
                              aligned_embed_dim);
        }

        /* Gate and up projections via the quantized MLP dispatch. */
        gemm_nt_q8_k_mlp_dispatch(q8_tile, w_gate, b_gate, gate_tile,
                                  tile_m, inter, aligned_embed_dim, w1_dt);
        gemm_nt_q8_k_mlp_dispatch(q8_tile, w_up, b_up, up_tile,
                                  tile_m, inter, aligned_embed_dim, w1_dt);

        /* SwiGLU in place: gate := silu(gate) * up. */
        for (int i = 0; i < tile_m; ++i) {
            float *g = gate_tile + (size_t)i * (size_t)inter;
            float *u = up_tile + (size_t)i * (size_t)inter;
            for (int j = 0; j < inter; ++j) {
                g[j] = silu_prefill(g[j]) * u[j];
            }
        }

        /* Re-quantize the SwiGLU output for the down projection. */
        for (int i = 0; i < tile_m; ++i) {
            const float *row = gate_tile + (size_t)i * (size_t)inter;
            quantize_row_q8_k(row, q8k_tile + (size_t)i * q8k_row_bytes,
                              aligned_intermediate_dim);
        }

        gemm_nt_q8_k_mlp_dispatch(q8k_tile, (const uint8_t *)W2, B2, out_tile,
                                  tile_m, aligned_embed_dim, aligned_intermediate_dim,
                                  w2_dt);
    }
}
/* Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant. Mirrors the
   kernel's layout above; the q8 row sizes use the same assumed activation
   dtype constant, and the return expression is reconstructed. */
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim,
                                                        int aligned_intermediate_dim)
{
    if (aligned_embed_dim <= 0 || aligned_intermediate_dim <= 0) {
        return 0;
    }
    const size_t gate_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_intermediate_dim * sizeof(float);
    const size_t up_bytes = gate_bytes;
    const size_t q8_bytes = (size_t)PREFILL_TILE_M * ck_dtype_row_bytes(CK_DTYPE_Q8_K, (size_t)aligned_embed_dim);
    const size_t q8k_bytes = (size_t)PREFILL_TILE_M * ck_dtype_row_bytes(CK_DTYPE_Q8_K, (size_t)aligned_intermediate_dim);
    return align_up_size(q8_bytes, 64) + align_up_size(gate_bytes, 64) +
           align_up_size(up_bytes, 64) + q8k_bytes;
}
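/* Usage sketch (illustrative): the quantized MLP's internal offsets are
   64-byte aligned, so allocate the scratch with matching alignment
   (aligned_alloc requires the size to be a multiple of the alignment, hence
   the align_up_size). Assumes <stdlib.h> with C11. */
static void *example_alloc_mlp_quant_scratch(int aligned_embed_dim,
                                             int aligned_intermediate_dim)
{
    size_t bytes = fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(
        aligned_embed_dim, aligned_intermediate_dim);
    return bytes ? aligned_alloc(64, align_up_size(bytes, 64)) : NULL;
}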
CKDataType
Supported data types in C-Kernel-Engine.
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void quantize_row_q8_k(const float *x, void *y, int k)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void gemm_nt_q8_0_q8_0(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
Quantization block structures for weight-only quantization.
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
static void gemm_nt_q8_0_mlp_dispatch(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
static void gemm_nt_q8_0_dispatch(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
static int qkv_q8_k_dtype_supported(CKDataType dt)
static size_t align_up_size(size_t value, size_t align)
static int mlp_q8_k_dtype_supported(CKDataType dt)
void fused_mlp_swiglu_prefill(const float *x, const float *W_gate, const float *W_up, const float *W_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
Fused MLP (Gate + Up + SwiGLU + Down) for prefill.
static float silu_prefill(float x)
static void fused_rmsnorm_gemm_2d_tiled(const float *x, const float *gamma, const float *W, float *output, int seq_len, int hidden, int out_dim, float eps, float *x_norm_scratch)
Fused RMSNorm + single GEMM with 2D tiling (weight reuse)
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
static void gemm_nt_q8_k_qkv_dispatch(const void *A_q8k, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
static void add_bias_tile(float *out, const float *bias, int tile_m, int out_dim)
void fused_rmsnorm_qkv_prefill(const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (v3 optimized)
void unfused_rmsnorm_qkv_prefill(const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *x_norm, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps)
Unfused version for comparison.
static int qkv_q8_0_dtype_supported(CKDataType dt)
static int mlp_q8_0_dtype_supported(CKDataType dt)
static void gemm_tile_nt_strided(const float *A, const float *B_tile, float *C, int tile_m, int tile_n, int K, int C_stride)
GEMM tile with N-dimension tiling (weight reuse)
size_t fused_rmsnorm_qkv_scratch_size(int hidden)
Get scratch size for fused prefill.
static void gemm_nt_q8_k_mlp_dispatch(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
void fused_mlp_swiglu_prefill_bias(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *B_gate, const float *B_up, const float *B_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
Fused MLP for prefill with proper tiling.
static void rmsnorm_tile(const float *input, const float *gamma, float *output, int tile_m, int embed_dim, int aligned_embed_dim, float eps)
Compute RMSNorm for a tile of tokens.
void fused_mlp_swiglu_prefill_w1w2_quant(const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch)
Quantized fused MLP for prefill (W1=gate+up, W2=down)
size_t fused_mlp_swiglu_scratch_size(int intermediate)
Get scratch size for fused MLP.
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
static void silu(float *x, int n)