26 return (value + align - 1) & ~(align - 1);
36 (
size_t)aligned_head_dim);
37 const size_t head_stride = (size_t)tokens * (
size_t)aligned_head_dim;
38 for (
int h = 0; h < num_heads; ++h) {
39 const float *head = attn_out + (size_t)h * head_stride;
40 for (
int t = 0; t < tokens; ++t) {
41 const float *row = head + (size_t)t * (
size_t)aligned_head_dim;
42 uint8_t *out = dst + ((size_t)h * (
size_t)tokens + (size_t)t) *
54 int aligned_embed_dim,
59 (
size_t)aligned_head_dim);
60 const int blocks_per_head = aligned_head_dim /
QK8_0;
61 const int blocks_per_row = aligned_embed_dim /
QK8_0;
64 for (
int t = 0; t < tokens; ++t) {
65 float *out_row = output + (size_t)t * (
size_t)aligned_embed_dim;
66 for (
int n = 0; n < aligned_embed_dim; ++n) {
67 float sum = bias ? bias[n] : 0.0f;
68 const block_q8_0 *w_row = weights + (size_t)n * (
size_t)blocks_per_row;
70 for (
int h = 0; h < num_heads; ++h) {
71 const uint8_t *a_row = attn_q8 +
72 ((size_t)h * (
size_t)tokens + (size_t)t) *
74 const block_q8_0 *w_head = w_row + (size_t)h * (
size_t)blocks_per_head;
85 int aligned_embed_dim,
89 if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 || aligned_head_dim <= 0) {
93 const size_t q_bytes = (size_t)num_heads * (
size_t)tokens *
94 (size_t)aligned_head_dim *
sizeof(
float);
95 const size_t attn_bytes = q_bytes;
96 const size_t proj_bytes = (size_t)tokens * (
size_t)aligned_embed_dim *
sizeof(float);
108 const float *residual,
109 const float *ln1_gamma,
110 const void *wq,
const float *bq,
CKDataType wq_dt,
111 const void *wk,
const float *bk,
CKDataType wk_dt,
112 const void *wv,
const float *bv,
CKDataType wv_dt,
113 const void *wo,
const float *bo,
CKDataType wo_dt,
116 const float *rope_cos,
117 const float *rope_sin,
122 int aligned_embed_dim,
126 int aligned_head_dim,
130 if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
131 !kv_cache_k || !kv_cache_v || !scratch) {
134 if (tokens <= 0 || cache_capacity <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
135 head_dim <= 0 || aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0) {
138 if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
141 if (start_pos < 0 || start_pos + tokens > cache_capacity) {
147 if ((aligned_head_dim %
QK8_0) != 0 || (aligned_embed_dim %
QK8_0) != 0) {
151 const size_t q_bytes = (size_t)num_heads * (
size_t)tokens *
152 (size_t)aligned_head_dim *
sizeof(
float);
153 const size_t attn_bytes = q_bytes;
154 const size_t proj_bytes = (size_t)tokens * (
size_t)aligned_embed_dim *
sizeof(float);
157 uint8_t *scratch_bytes = (uint8_t *)scratch;
158 float *q = (
float *)scratch_bytes;
160 float *attn_out = (
float *)scratch_bytes;
162 float *proj_scratch = (
float *)scratch_bytes;
164 void *qkv_scratch = (
void *)scratch_bytes;
165 (void)qkv_scratch_bytes;
168 float *k_ptr = kv_cache_k + (size_t)start_pos * (
size_t)aligned_head_dim;
169 float *v_ptr = kv_cache_v + (size_t)start_pos * (
size_t)aligned_head_dim;
174 (
const float *)wq, bq,
175 (
const float *)wk, bk,
176 (
const float *)wv, bv,
211 if (rope_cos && rope_sin) {
226 if (start_pos == 0) {
238 const float scale = 1.0f / sqrtf((
float)head_dim);
239 const size_t q_head_stride = (size_t)tokens * (
size_t)aligned_head_dim;
240 const size_t kv_head_stride = (size_t)cache_capacity * (
size_t)aligned_head_dim;
242 for (
int h = 0; h < num_heads; ++h) {
243 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
244 const float *k_head = kv_cache_k + (size_t)kv_head * kv_head_stride;
245 const float *v_head = kv_cache_v + (size_t)kv_head * kv_head_stride;
247 for (
int i = 0; i < tokens; ++i) {
248 const float *q_vec = q + (size_t)h * q_head_stride + (
size_t)i * (size_t)aligned_head_dim;
249 float *out_vec = attn_out + (size_t)h * q_head_stride + (
size_t)i * (size_t)aligned_head_dim;
263 if ((num_heads * aligned_head_dim) != aligned_embed_dim) {
269 uint8_t *attn_q8 = (uint8_t *)q;
CKDataType
Supported data types in the C-Kernel-Engine.
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
Quantization block structures for weight-only quantization.
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
size_t mega_fused_attention_prefill_q8_0_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill_q8_0.
static void out_proj_head_major_q8_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static size_t align_up_size(size_t value, size_t align)
Round value up to the nearest multiple of align (align must be a power of two).
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
void mega_fused_attention_prefill_q8_0(float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
Mega-fused prefill attention kernel (Q8_0 out-proj)