40 return (value + align - 1) & ~(align - 1);
46 int aligned_embed_dim,
50 const size_t head_in_stride = (size_t)tokens * (
size_t)aligned_head_dim;
51 for (
int t = 0; t < tokens; ++t) {
52 float *out_row = dst + (size_t)t * (
size_t)aligned_embed_dim;
53 for (
int h = 0; h < num_heads; ++h) {
54 const float *src = attn_out + (size_t)h * head_in_stride +
55 (
size_t)t * (size_t)aligned_head_dim;
56 memcpy(out_row + (
size_t)h * (
size_t)aligned_head_dim,
58 (
size_t)aligned_head_dim *
sizeof(
float));
65 static int cached = -2;
70 const char *env = getenv(
"CK_Q8_0_OUTPROJ");
71 if (!env || !env[0]) {
75 if (env[0] ==
'0' || env[0] ==
'n' || env[0] ==
'N' ||
76 env[0] ==
'f' || env[0] ==
'F') {
91 (
size_t)aligned_head_dim);
92 const size_t head_stride = (size_t)tokens * (
size_t)aligned_head_dim;
93 for (
int h = 0; h < num_heads; ++h) {
94 const float *head = attn_out + (size_t)h * head_stride;
95 for (
int t = 0; t < tokens; ++t) {
96 const float *row = head + (size_t)t * (
size_t)aligned_head_dim;
97 uint8_t *out = dst + ((size_t)h * (
size_t)tokens + (size_t)t) *
109 int aligned_embed_dim,
111 int aligned_head_dim)
114 (
size_t)aligned_head_dim);
115 const int blocks_per_head = aligned_head_dim /
QK5_0;
116 const int blocks_per_row = aligned_embed_dim /
QK5_0;
119 for (
int t = 0; t < tokens; ++t) {
120 float *out_row = output + (size_t)t * (
size_t)aligned_embed_dim;
121 for (
int n = 0; n < aligned_embed_dim; ++n) {
122 float sum = bias ? bias[n] : 0.0f;
123 const block_q5_0 *w_row = weights + (size_t)n * (
size_t)blocks_per_row;
125 for (
int h = 0; h < num_heads; ++h) {
126 const uint8_t *a_row = attn_q8 +
127 ((size_t)h * (
size_t)tokens + (size_t)t) *
129 const block_q5_0 *w_head = w_row + (size_t)h * (
size_t)blocks_per_head;
130 float partial = 0.0f;
140 int aligned_embed_dim,
142 int aligned_head_dim)
144 if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 || aligned_head_dim <= 0) {
148 const size_t q_bytes = (size_t)num_heads * (
size_t)tokens *
149 (size_t)aligned_head_dim *
sizeof(
float);
150 const size_t attn_bytes = q_bytes;
151 const size_t proj_bytes = (size_t)tokens * (
size_t)aligned_embed_dim *
sizeof(float);
163 const float *residual,
164 const float *ln1_gamma,
165 const void *wq,
const float *bq,
CKDataType wq_dt,
166 const void *wk,
const float *bk,
CKDataType wk_dt,
167 const void *wv,
const float *bv,
CKDataType wv_dt,
168 const void *wo,
const float *bo,
CKDataType wo_dt,
171 const float *rope_cos,
172 const float *rope_sin,
177 int aligned_embed_dim,
181 int aligned_head_dim,
185 if (!output || !input || !ln1_gamma || !wq || !wk || !wv || !wo ||
186 !kv_cache_k || !kv_cache_v || !scratch) {
189 if (tokens <= 0 || cache_capacity <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
190 head_dim <= 0 || aligned_head_dim <= 0 || num_heads <= 0 || num_kv_heads <= 0) {
193 if (aligned_embed_dim < embed_dim || aligned_head_dim < head_dim) {
196 if (start_pos < 0 || start_pos + tokens > cache_capacity) {
200 const size_t q_bytes = (size_t)num_heads * (
size_t)tokens *
201 (size_t)aligned_head_dim *
sizeof(
float);
202 const size_t attn_bytes = q_bytes;
203 const size_t proj_bytes = (size_t)tokens * (
size_t)aligned_embed_dim *
sizeof(float);
206 uint8_t *scratch_bytes = (uint8_t *)scratch;
207 float *q = (
float *)scratch_bytes;
209 float *attn_out = (
float *)scratch_bytes;
211 float *proj_scratch = (
float *)scratch_bytes;
213 void *qkv_scratch = (
void *)scratch_bytes;
214 (void)qkv_scratch_bytes;
216 float *k_ptr = kv_cache_k + (size_t)start_pos * (
size_t)aligned_head_dim;
217 float *v_ptr = kv_cache_v + (size_t)start_pos * (
size_t)aligned_head_dim;
222 (
const float *)wq, bq,
223 (
const float *)wk, bk,
224 (
const float *)wv, bv,
259 if (rope_cos && rope_sin) {
274 if (start_pos == 0) {
286 const float scale = 1.0f / sqrtf((
float)head_dim);
287 const size_t q_head_stride = (size_t)tokens * (
size_t)aligned_head_dim;
288 const size_t kv_head_stride = (size_t)cache_capacity * (
size_t)aligned_head_dim;
290 for (
int h = 0; h < num_heads; ++h) {
291 int kv_head = (int)((
long long)h * (
long long)num_kv_heads / (
long long)num_heads);
292 const float *k_head = kv_cache_k + (size_t)kv_head * kv_head_stride;
293 const float *v_head = kv_cache_v + (size_t)kv_head * kv_head_stride;
295 for (
int i = 0; i < tokens; ++i) {
296 const float *q_vec = q + (size_t)h * q_head_stride + (
size_t)i * (size_t)aligned_head_dim;
297 float *out_vec = attn_out + (size_t)h * q_head_stride + (
size_t)i * (size_t)aligned_head_dim;
311 if ((num_heads * aligned_head_dim) != aligned_embed_dim) {
317 (aligned_head_dim %
QK5_0) == 0 &&
318 (aligned_embed_dim %
QK5_0) == 0) {
320 uint8_t *attn_q8 = (uint8_t *)q;
335 (aligned_head_dim %
QK5_0) == 0 &&
336 (aligned_embed_dim %
QK5_0) == 0) {
347 (aligned_head_dim %
QK8_0) == 0 &&
348 (aligned_embed_dim %
QK8_0) == 0) {
CKDataType
Supported data types in C-Kernel-Engine.
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void ck_gemm_nt_head_major_q8_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q8_0 weights)
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
void rope_forward_qk_strided(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
void ck_gemm_nt_head_major_q5_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q5_0 weights; auto-dispatch)
void fused_rmsnorm_qkv_prefill_head_major(const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
void fused_rmsnorm_qkv_prefill_head_major_quant(const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
void ck_gemm_nt_quant(const float *A, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dtype)
General M×N×K GEMM with B transposed and quantized (dtype selects the weight format); presumably adds bias per output column when non-NULL — confirm against implementation.
void ck_residual_add_token_major(const float *a, const float *b, float *out, int tokens, int aligned_embed_dim)
Quantization block structures for weight-only quantization.
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
size_t mega_fused_attention_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill.
void mega_fused_attention_prefill(float *output, const float *input, const float *residual, const float *ln1_gamma, const void *wq, const float *bq, CKDataType wq_dt, const void *wk, const float *bk, CKDataType wk_dt, const void *wv, const float *bv, CKDataType wv_dt, const void *wo, const float *bo, CKDataType wo_dt, float *kv_cache_k, float *kv_cache_v, const float *rope_cos, const float *rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void *scratch)
Mega-fused attention for prefill mode (multiple tokens)
static size_t align_up_size(size_t value, size_t align)
static void flatten_head_major(const float *attn_out, float *dst, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void out_proj_head_major_q5_0_q8_0(const uint8_t *attn_q8, const void *wo, const float *bias, float *output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
static void quantize_attn_out_head_major_q8_0(const float *attn_out, uint8_t *dst, int tokens, int num_heads, int aligned_head_dim)
static int ck_q8_0_outproj_enabled(void)