#include <stddef.h>
#include <stdint.h>
#include "cpu_features.h"
#include "ckernel_quant.h"
#include "mega_fused_attention.h"
Go to the source code of this file.
Data Structures | |
| struct | CKMathBackend |
Functions | |
| void | add_backward_bf16 (const uint16_t *d_y, uint16_t *d_a, uint16_t *d_b, size_t n) |
| void | add_forward_2d_bf16 (const uint16_t *a, const uint16_t *b, uint16_t *y, int tokens, int dim, int aligned_dim) |
| void | add_forward_bf16 (const uint16_t *a, const uint16_t *b, uint16_t *y, size_t n) |
| void | add_forward_f32 (const float *a, const float *b, float *y, size_t n) |
| void | add_inplace_bf16 (uint16_t *a, const uint16_t *b, size_t n) |
| void | add_inplace_f32 (float *a, const float *b, size_t n) |
| void | add_scaled_forward_bf16 (const uint16_t *a, const uint16_t *b, uint16_t *y, float alpha, size_t n) |
| void | add_scaled_inplace_bf16 (uint16_t *a, const uint16_t *b, float alpha, size_t n) |
| int | argmax_f32 (const float *scores, int n) |
| Find index of maximum value. More... | |
| void | attention_backward_causal_head_major (const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_backward_causal_head_major_gqa (const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_backward_causal_head_major_gqa_bf16 (const uint16_t *d_output, float *d_x, const uint16_t *q, const uint16_t *k, const uint16_t *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_d_output, float *scratch_q, float *scratch_k, float *scratch_v) |
| void | attention_flash_decode (float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale) |
| Main flash attention function with SIMD dispatch. More... | |
| void | attention_forward_causal_head_major (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_forward_causal_head_major_exact (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_forward_causal_head_major_gqa (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_forward_causal_head_major_gqa_bf16 (const uint16_t *q, const uint16_t *k, const uint16_t *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_q, float *scratch_k, float *scratch_v) |
| void | attention_forward_causal_head_major_gqa_exact (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window) |
| void | attention_forward_causal_head_major_gqa_flash (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim) |
| void | attention_forward_causal_head_major_gqa_flash_strided (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens) |
| void | attention_forward_causal_head_major_gqa_flash_strided_sliding (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window) |
| void | attention_forward_decode_head_major_gqa_flash (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim) |
| void | attention_forward_decode_head_major_gqa_flash_sliding (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window) |
| void | attention_forward_decode_head_major_gqa_regular (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim) |
| WARNING: This is NOT true flash attention! More... | |
| void | axpy_2d_f32 (float *Y, const float *X, float alpha, int num_tokens, int dim, int y_stride, int x_stride) |
| Batched AXPY for 2D tensors: Y[t,:] += alpha * X[t,:]. More... | |
| void | axpy_f32 (float *y, const float *x, float alpha, int n) |
| In-place AXPY: y += alpha * x. More... | |
| void | axpy_zero_f32 (float *y, const float *x, float alpha, int n) |
| Zero output then accumulate: y = 0; y += alpha * x. More... | |
| void | backward_causal_softmax_head_major (float *d_scores, const float *weights, int num_heads, int num_tokens, int aligned_context_window) |
| void | backward_causal_softmax_head_major_bf16 (uint16_t *d_scores, const uint16_t *weights, int num_heads, int num_tokens, int aligned_context_window, float *scratch_d_scores, float *scratch_weights) |
| void | causal_softmax_head_major (float *scores, int num_heads, int num_tokens, int aligned_context_window) |
| void | causal_softmax_head_major_bf16 (uint16_t *scores, int num_heads, int num_tokens, int aligned_context_window, float *scratch) |
| void | causal_softmax_head_major_exact (float *scores, int num_heads, int num_tokens, int aligned_context_window) |
| void | ck_attention_flash_decode_wrapper (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim) |
| Wrapper to call TRUE flash attention from orchestration layer. More... | |
| int | ck_flash_attn_choose_tile_k (int D_h) |
| int | ck_flash_attn_fast_exp_kind (void) |
| void | ck_gemm_nt_head_major_q5_0 (const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim) |
| Output projection from head-major attention (auto-dispatch) More... | |
| void | ck_gemm_nt_head_major_q8_0 (const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim) |
| Output projection from head-major attention (Q8_0 weights) More... | |
| int | ck_get_num_threads (void) |
| int | ck_get_physical_cores (void) |
| void | ck_set_num_threads (int num_threads) |
| void | ck_set_strict_parity (int enabled) |
| int | ck_strict_parity_enabled (void) |
| CKMathBackend | ckernel_backend_native (void) |
| void | dequant_q4_0_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q4_0 row (multiple blocks) More... | |
| void | dequant_q4_1_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q4_1 row (multiple blocks) More... | |
| void | dequant_q4_k_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q4_K row (multiple blocks) More... | |
| void | dequant_q5_0_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q5_0 row (multiple blocks) More... | |
| void | dequant_q5_1_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q5_1 row (multiple blocks) More... | |
| void | dequant_q6_k_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q6_K row (multiple blocks) More... | |
| void | dequant_q8_0_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q8_0 row (multiple blocks) More... | |
| void | embedding_backward (const int32_t *token_ids, int token_count, const float *d_output, float *d_token_embeddings, float *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos) |
| void | embedding_backward_bf16 (const int32_t *token_ids, int token_count, const uint16_t *d_output, uint16_t *d_token_embeddings, uint16_t *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos) |
| void | embedding_forward (const int32_t *token_ids, int token_count, int vocab_size, const float *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos) |
| void | embedding_forward_bf16 (const int32_t *token_ids, int token_count, int vocab_size, const uint16_t *token_embeddings, const uint16_t *pos_embeddings, uint16_t *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos) |
| void | embedding_forward_q4_k (const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos) |
| void | embedding_forward_q6_k (const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos) |
| void | embedding_forward_q8_0 (const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos) |
| void | fc1_backward_kernel (const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads) |
| void | fc2_backward_kernel (const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads) |
| void | fused_mlp_swiglu_decode (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff) |
| void | fused_mlp_swiglu_decode_tiled (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff) |
| void | fused_mlp_swiglu_decode_v2 (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff) |
| void | fused_mlp_swiglu_prefill (const float *x, const float *W_gate, const float *W_up, const float *W_down, float *output, int seq_len, int hidden, int intermediate, float *scratch) |
| Fused MLP (Gate + Up + SwiGLU + Down) for prefill. More... | |
| void | fused_mlp_swiglu_prefill_bias (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *B_gate, const float *B_up, const float *B_down, float *output, int seq_len, int hidden, int intermediate, float *scratch) |
| Fused MLP (Gate + Up + SwiGLU + Down) for prefill with biases. More... | |
| void | fused_mlp_swiglu_prefill_w1w2_quant (const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch) |
| Quantized fused MLP for prefill (W1=gate+up, W2=down) More... | |
| size_t | fused_mlp_swiglu_prefill_w1w2_quant_scratch_size (int aligned_embed_dim, int aligned_intermediate_dim) |
| Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant. More... | |
| size_t | fused_mlp_swiglu_scratch_size (int intermediate) |
| Get scratch buffer size for fused_mlp_swiglu_prefill. More... | |
| void | fused_rmsnorm_qkv_prefill (const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps, float *scratch) |
| Fused RMSNorm + QKV projection for prefill. More... | |
| void | fused_rmsnorm_qkv_prefill_head_major (const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch) |
| Fused RMSNorm + QKV projection for prefill (head-major outputs) More... | |
| void | fused_rmsnorm_qkv_prefill_head_major_quant (const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch) |
| Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations) More... | |
| size_t | fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size (int aligned_embed_dim) |
| Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant. More... | |
| size_t | fused_rmsnorm_qkv_scratch_size (int hidden) |
| Get scratch buffer size for fused_rmsnorm_qkv_prefill. More... | |
| void | geglu_backward_fp32 (const float *x, const float *d_out, float *d_x, int tokens, int dim) |
| void | geglu_forward_bf16 (const uint16_t *x, uint16_t *out, int tokens, int dim, float *scratch) |
| void | geglu_forward_fp32 (const float *x, float *out, int tokens, int dim) |
| void | gelu_backward_exact (const float *input, const float *d_output, float *d_input, size_t n) |
| void | gelu_backward_exact_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input) |
| void | gelu_backward_fast (const float *input, const float *d_output, float *d_input, size_t n) |
| void | gelu_backward_fast_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input) |
| void | gelu_backward_scalar (const float *input, const float *d_output, float *d_input, size_t n) |
| void | gelu_exact_inplace (float *data, size_t n) |
| void | gelu_fast_inplace (float *data, size_t n) |
| void | gelu_fast_inplace_bf16 (uint16_t *data, size_t n, float *scratch) |
| void | gemm_avx512_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_bias_gelu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_bias_relu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_bias_silu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_blocked_serial (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_blocked_serial_bf16 (const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K) |
| void | gemm_fine_grained_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_microkernel (const float *A, const float *B, float *C, int M, int N, int K, int B_transposed) |
| void | gemm_microkernel_blocked (const float *A, const float *B, float *C, int M, int N, int K) |
| void | gemm_microkernel_blocked_bt (const float *A, const float *B, float *C, int M, int N, int K) |
| void | gemm_microkernel_packed (const float *A, const float *B, float *C, int M, int N, int K) |
| void | gemm_naive_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nn_avx512 (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nn_blocked (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nn_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nt_q4_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias. More... | |
| void | gemm_nt_q4_1 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| GEMM with transposed Q4_1 weights: C = A @ B^T. More... | |
| void | gemm_nt_q4_k (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nt_q4_k_q8_k (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nt_q5_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nt_q5_1 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| GEMM with transposed Q5_1 weights: C = A @ B^T. More... | |
| void | gemm_nt_q5_k (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nt_q6_k (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nt_q6_k_q8_k (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K) |
| NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K. More... | |
| void | gemm_nt_q8_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias. More... | |
| void | gemm_nt_q8_0_q8_0 (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K) |
| gemm_nt_q8_0_q8_0 with optional bias (matches header signature) More... | |
| void | gemm_q4_k (float *Y, const void *W, const float *X, int M, int N, int K) |
| Auto-dispatch GEMM based on available SIMD. More... | |
| void | gemm_q4_k_q8_k (float *Y, const void *W, const void *X_q8, int M, int N, int K) |
| void | gemm_q6_k (float *Y, const void *W, const float *X, int M, int N, int K) |
| void | gemm_q6_k_q8_k (float *Y, const void *W, const void *X_q8, int M, int N, int K) |
| GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K. More... | |
| void | gemm_swiglu_fused (const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K) |
| void | gemm_tn_avx512 (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_tn_blocked (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_tn_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemv_fused_q5_0_bias_dispatch (float *y, const void *W, const float *x, const float *bias, int M, int K) |
| void | gemv_fused_q8_0_bias_dispatch (float *y, const void *W, const float *x, const float *bias, int M, int K) |
| void | gemv_q4_0 (float *y, const void *W, const float *x, int M, int K) |
| Auto-dispatch GEMV. More... | |
| void | gemv_q4_k (float *y, const void *W, const float *x, int M, int K) |
| Auto-dispatch GEMV based on available SIMD. More... | |
| void | gemv_q4_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K) |
| void | gemv_q4_k_q8_k_parallel (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth) |
| void | gemv_q4_k_q8_k_parallel_simd (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth) |
| void | gemv_q4_k_q8_k_ref (float *y, const void *W, const void *x_q8, int M, int K) |
| void | gemv_q5_0 (float *y, const void *W, const float *x, int M, int K) |
| Auto-dispatch GEMV for Q5_0 weights based on CPU features. More... | |
| void | gemv_q5_0_parallel (float *y, const void *W, const float *x, int M, int K, int ith, int nth) |
| Parallel reference GEMV for Q5_0 × FP32. More... | |
| void | gemv_q5_0_parallel_simd (float *y, const void *W, const float *x, int M, int K, int ith, int nth) |
| Parallel SIMD GEMV for Q5_0 × FP32 with prefetching. More... | |
| void | gemv_q5_0_q8_0 (float *y, const void *W, const void *x_q8, int M, int K) |
| Matrix-vector multiply with Q5_0 weights and Q8_0 input. More... | |
| void | gemv_q5_1 (float *y, const void *W, const float *x, int M, int K) |
| Auto-dispatch GEMV. More... | |
| void | gemv_q5_k (float *y, const void *W, const float *x, int M, int K) |
| void | gemv_q6_k (float *y, const void *W, const float *x, int M, int K) |
| void | gemv_q6_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K) |
| GEMV: y = W @ x where W is Q6_K and x is Q8_K. More... | |
| void | gemv_q6_k_q8_k_parallel (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth) |
| Parallel reference GEMV for Q6_K × Q8_K. More... | |
| void | gemv_q6_k_q8_k_parallel_simd (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth) |
| Parallel SIMD GEMV for Q6_K × Q8_K. More... | |
| void | gemv_q8_0 (float *y, const void *W, const float *x, int M, int K) |
| Auto-dispatch GEMV for Q8_0 weights based on CPU features. More... | |
| void | gemv_q8_0_q8_0 (float *y, const void *W, const void *x_q8, int M, int K) |
| Matrix-vector multiply with Q8_0 weights and Q8_0 input. More... | |
| void | im2patch (const float *image, float *patches, int C, int H, int W, int P) |
| void | im2patch_bf16 (const uint16_t *image, uint16_t *patches, int C, int H, int W, int P) |
| void | kv_cache_repack_head_major_inplace (float *buf, int num_heads, int tokens, int cache_capacity, int aligned_head_dim) |
| void | kv_cache_store (float *__restrict kv_cache_k, float *__restrict kv_cache_v, const float *__restrict k, const float *__restrict v, int layer, int pos, int num_kv_heads, int head_dim, int max_seq_len) |
| void | kv_cache_write_head_major (const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim) |
| void | layernorm_backward_kernel (const float *d_output, const float *input, const float *gamma, const float *mean, const float *rstd, float *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim) |
| void | layernorm_backward_kernel_bf16 (const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *mean, const float *rstd, uint16_t *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input) |
| void | layernorm_forward_rolled_slice (const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps) |
| void | layernorm_forward_rolled_slice_bf16 (const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output) |
| void | layernorm_forward_unrolled_slice (const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps) |
| void | layernorm_forward_unrolled_slice_bf16 (const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps, float *scratch_input, float *scratch_output) |
| void | layernorm_naive_serial (const float *input, const float *gamma, const float *beta, float *output, float *mean_cache, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps) |
| void | layernorm_naive_serial_matched_precision (const float *input, const float *gamma, const float *beta, float *output, float *mean_cache, float *rstd_cache, int tokens, int d_model, float eps) |
| void | mlp_token_parallel (const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads) |
| void | mlp_token_parallel_bf16 (const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16) |
| void | mlp_token_parallel_bf16_fp32act (const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_input_f, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16) |
| void | mlp_token_parallel_exact (const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads) |
| void | moe_accumulate_expert_f32 (float *output, const float *expert_output, float routing_weight, int hidden_dim) |
| Accumulate expert output: output += routing_weight * expert_output. More... | |
| void | patch2im (const float *d_patches, float *d_image, int C, int H, int W, int P) |
| void | patch2im_bf16 (const uint16_t *d_patches, uint16_t *d_image, int C, int H, int W, int P) |
| void | quantize_batch_q8_0 (const float *x, void *y, int num_rows, int k) |
| Batch quantize FP32 to Q8_0 format (row-major output) More... | |
| void | quantize_batch_q8_k (const float *x, void *y, int num_rows, int k) |
| Batch quantize FP32 to Q8_K format (row-major output) More... | |
| void | quantize_row_q8_0 (const float *x, void *y, int k) |
| Quantize FP32 to Q8_0 format (scalar reference) More... | |
| void | quantize_row_q8_k (const float *x, void *y, int k) |
| void | relu_backward (const float *input, const float *d_output, float *d_input, size_t n) |
| void | relu_backward_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n) |
| void | relu_forward (const float *input, float *output, size_t n) |
| void | relu_forward_bf16 (const uint16_t *input, uint16_t *output, size_t n) |
| void | relu_forward_inplace (float *data, size_t n) |
| void | relu_forward_inplace_bf16 (uint16_t *data, size_t n) |
| void | rmsnorm_backward (const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim) |
| void | rmsnorm_backward_bf16 (const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *rstd_cache, uint16_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim) |
| void | rmsnorm_backward_int4 (const uint8_t *d_output, const uint8_t *input, const float *gamma, const float *rstd_cache, uint8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input) |
| void | rmsnorm_backward_int8 (const int8_t *d_output, const int8_t *input, const float *gamma, const float *rstd_cache, int8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input) |
| void | rmsnorm_forward (const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps) |
| void | rmsnorm_forward_bf16 (const uint16_t *input, const float *gamma, uint16_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps) |
| void | rmsnorm_forward_int4 (const uint8_t *input, const float *gamma, uint8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output) |
| void | rmsnorm_forward_int8 (const int8_t *input, const float *gamma, int8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output) |
| void | rope_backward (const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset) |
| void | rope_backward_bf16 (const uint16_t *d_out, uint16_t *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_d_out, float *scratch_d_x) |
| void | rope_backward_inplace (float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset) |
| void | rope_backward_qk (const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset) |
| void | rope_backward_qk_bf16 (const uint16_t *d_q_out, const uint16_t *d_k_out, uint16_t *d_q, uint16_t *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_dq_out, float *scratch_dq, float *scratch_dk_out, float *scratch_dk) |
| void | rope_forward (float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset) |
| void | rope_forward_bf16 (uint16_t *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch) |
| void | rope_forward_qk (float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset) |
| void | rope_forward_qk_bf16 (uint16_t *q, uint16_t *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_q, float *scratch_k) |
| void | rope_forward_qk_strided (float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens) |
| void | rope_forward_strided (float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens) |
| void | rope_precompute_cache (float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base) |
| void | scal_copy_f32 (float *y, const float *x, float alpha, int n) |
| Scaled copy: y = alpha * x. More... | |
| void | sigmoid_backward (const float *input, const float *d_output, float *d_input, size_t n) |
| void | sigmoid_backward_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input) |
| void | sigmoid_forward (const float *input, float *output, size_t n) |
| void | sigmoid_forward_bf16 (const uint16_t *input, uint16_t *output, size_t n, float *scratch_input, float *scratch_output) |
| float | sigmoid_scalar (float x) |
| void | softmax_cross_entropy_loss (const float *logits, const int32_t *targets, int tokens, int vocab_size, float *d_logits, float *loss_out) |
| void | softmax_cross_entropy_loss_bf16 (const uint16_t *logits, const int32_t *targets, int tokens, int vocab_size, uint16_t *d_logits, float *loss_out, float *scratch_logits, float *scratch_d_logits) |
| void | swiglu_backward (const float *input, const float *d_output, float *d_input, int tokens, int dim) |
| void | swiglu_backward_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, int tokens, int dim) |
| void | swiglu_backward_exact (const float *input, const float *d_output, float *d_input, int tokens, int dim) |
| void | swiglu_forward (const float *input, float *output, int tokens, int dim) |
| void | swiglu_forward_bf16 (const uint16_t *input, uint16_t *output, int tokens, int dim) |
| void | swiglu_forward_exact (const float *input, float *output, int tokens, int dim) |
| void | topk_batched_f32 (const float *scores, int num_tokens, int n_experts, int k, int *indices, float *weights) |
| Batched top-K selection for multiple tokens. More... | |
| void | topk_f32 (const float *scores, int n, int k, int *indices, float *values) |
| Find top-K indices and values from a score vector. More... | |
| void | topk_softmax_f32 (const float *scores, int n, int k, int *indices, float *weights) |
| Find top-K indices with softmax-normalized weights. More... | |
| void | unfused_rmsnorm_qkv_prefill (const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *x_norm, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps) |
| Unfused version for benchmarking comparison. More... | |
| void | vec_dot_q6_k_q8_k (int n, float *s, const void *vx, const void *vy) |
| Q6_K x Q8_K dot product (single row) More... | |
| void | weighted_sum_f32 (float *y, const float **vectors, const float *weights, int k, int n) |
| Weighted sum of k vectors: y = sum_i(weights[i] * vectors[i]) More... | |
| void add_backward_bf16 | ( | const uint16_t * | d_y, |
| uint16_t * | d_a, | ||
| uint16_t * | d_b, | ||
| size_t | n | ||
| ) |
Definition at line 173 of file add_kernels_bf16.c.
| void add_forward_2d_bf16 | ( | const uint16_t * | a, |
| const uint16_t * | b, | ||
| uint16_t * | y, | ||
| int | tokens, | ||
| int | dim, | ||
| int | aligned_dim | ||
| ) |
| void add_forward_bf16 | ( | const uint16_t * | a, |
| const uint16_t * | b, | ||
| uint16_t * | y, | ||
| size_t | n | ||
| ) |
| void add_forward_f32 | ( | const float * | a, |
| const float * | b, | ||
| float * | y, | ||
| size_t | n | ||
| ) |
Element-wise add: y = a + b
test_add.py::TestAddForward::test_add_forward_f32
test_add.py::TestAddForward::test_add_inplace_f32
test_multi_layer_parity.py::TestMultiLayerParity::test_residual_add
Element-wise addition of two vectors.
After changes: make test
Definition at line 270 of file add_kernels_bf16.c.
| void add_inplace_bf16 | ( | uint16_t * | a, |
| const uint16_t * | b, | ||
| size_t | n | ||
| ) |
| void add_inplace_f32 | ( | float * | a, |
| const float * | b, | ||
| size_t | n | ||
| ) |
| void add_scaled_forward_bf16 | ( | const uint16_t * | a, |
| const uint16_t * | b, | ||
| uint16_t * | y, | ||
| float | alpha, | ||
| size_t | n | ||
| ) |
| void add_scaled_inplace_bf16 | ( | uint16_t * | a, |
| const uint16_t * | b, | ||
| float | alpha, | ||
| size_t | n | ||
| ) |
| int argmax_f32 | ( | const float * | scores, |
| int | n | ||
| ) |
Find index of maximum value.
| scores | Input scores [n] |
| n | Number of scores |
Definition at line 226 of file topk_kernels.c.
| void attention_backward_causal_head_major | ( | const float * | d_output, |
| const float * | q, | ||
| const float * | k, | ||
| const float * | v, | ||
| const float * | attn_weights, | ||
| float * | d_q, | ||
| float * | d_k, | ||
| float * | d_v, | ||
| float * | d_scores, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
Causal attention backward (non-GQA version)
test_attention_backward.py::TestAttentionBackward::test_backward
test_attention_backward.py::TestAttentionBackward::test_backward_vs_separate
test_parity.py::test_attention_backward_parity
Non-GQA version where num_heads == num_kv_heads. Simpler than GQA, no head broadcasting needed.
After changes: make test && make llamacpp-parity-full
Definition at line 1811 of file attention_kernels.c.
References attention_backward_causal_head_major_gqa().
| void attention_backward_causal_head_major_gqa | ( | const float * | d_output, |
| const float * | q, | ||
| const float * | k, | ||
| const float * | v, | ||
| const float * | attn_weights, | ||
| float * | d_q, | ||
| float * | d_k, | ||
| float * | d_v, | ||
| float * | d_scores, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
GQA causal attention backward (score-matrix version)
test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward
test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_vs_separate
test_parity.py::test_attention_backward_parity
Computes dQ, dK, dV given dOutput and attention weights. Supports grouped-query attention with head broadcasting.
After changes: make test && make llamacpp-parity-full
Definition at line 1672 of file attention_kernels.c.
References qkv_index(), and score_index().
Referenced by attention_backward_causal_head_major(), attention_backward_causal_head_major_gqa_bf16(), and ck_layer_backward_rmsnorm_swiglu().
| void attention_backward_causal_head_major_gqa_bf16 | ( | const uint16_t * | d_output, |
| float * | d_x, | ||
| const uint16_t * | q, | ||
| const uint16_t * | k, | ||
| const uint16_t * | v, | ||
| const float * | attn_weights, | ||
| float * | d_q, | ||
| float * | d_k, | ||
| float * | d_v, | ||
| float * | d_scores, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window, | ||
| float * | scratch_d_output, | ||
| float * | scratch_q, | ||
| float * | scratch_k, | ||
| float * | scratch_v | ||
| ) |
BF16 attention backward with caller-provided scratch buffers
Accepts BF16 inputs, converts to FP32, runs FP32 backward. Caller provides scratch buffers (no per-call malloc).
After changes: make test
Definition at line 1619 of file attention_kernels.c.
References attention_backward_causal_head_major_gqa(), and convert_bf16_tensor_to_buf().
| void attention_flash_decode | ( | float * | out, |
| const float * | q, | ||
| const float * | k, | ||
| const float * | v, | ||
| int | T_q, | ||
| int | T_k, | ||
| int | H, | ||
| int | D_h, | ||
| float | scale | ||
| ) |
Main flash attention function with SIMD dispatch.
| out | Output [T_q, H, D_h] |
| q | Query [T_q, H, D_h] |
| k | Key [T_k, H, D_h] |
| v | Value [T_k, H, D_h] |
| T_q | Number of query tokens (1 for decode) |
| T_k | Number of key/value tokens (context length) |
| H | Number of heads |
| D_h | Head dimension |
| scale | 1/sqrt(D_h) |
Definition at line 696 of file attention_flash_true.c.
References attention_flash_decode_scalar().
Referenced by attention_forward_decode_head_major_gqa_flash(), ck_attention_flash_decode_wrapper(), mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().
| void attention_forward_causal_head_major | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
Causal attention forward (score-matrix version)
test_attention.py::TestAttentionForward::test_causal_forward
test_attention.py::TestAttentionForward::test_gqa_broadcast
test_attention.py::TestAttentionForward::test_exact_vs_fast
test_parity.py::test_attention_parity
Computes softmax(Q @ K^T / sqrt(d)) @ V with causal masking. Uses O(N^2) memory for scores matrix.
After changes: make test && make llamacpp-parity-full
Definition at line 70 of file attention_kernels.c.
References causal_softmax_head_major(), qkv_index(), and score_index().
| void attention_forward_causal_head_major_exact | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
Causal attention forward (exact version using stdlib expf)
test_attention.py::TestAttentionForward::test_exact_single
test_attention.py::TestAttentionForward::test_exact_vs_fast
Uses standard library expf for numerical accuracy reference. Slower but provides maximum accuracy.
After changes: make test
Definition at line 146 of file attention_kernels.c.
References causal_softmax_head_major_exact(), qkv_index(), and score_index().
| void attention_forward_causal_head_major_gqa | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
GQA causal attention forward (score-matrix version)
test_attention.py::TestAttentionForward::test_gqa_forward
test_attention.py::TestAttentionForward::test_gqa_broadcast
test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward
test_parity.py::test_attention_gqa_parity
Grouped-query attention: Q has num_heads, K/V have num_kv_heads. Each query head maps to a KV head via ratio.
After changes: make test && make llamacpp-parity-full
Definition at line 224 of file attention_kernels.c.
References causal_softmax_head_major(), qkv_index(), and score_index().
Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), and ck_layer_forward_rmsnorm_swiglu_ref().
| void attention_forward_causal_head_major_gqa_bf16 | ( | const uint16_t * | q, |
| const uint16_t * | k, | ||
| const uint16_t * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window, | ||
| float * | scratch_q, | ||
| float * | scratch_k, | ||
| float * | scratch_v | ||
| ) |
BF16 GQA causal attention forward
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_forward
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_flash
Accepts BF16 inputs, converts to FP32, uses exact softmax. Caller provides scratch buffers (no per-call malloc).
After changes: make test
Definition at line 366 of file attention_kernels.c.
References attention_forward_causal_head_major_gqa_exact(), and convert_bf16_tensor_to_buf().
| void attention_forward_causal_head_major_gqa_exact | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | scores, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | aligned_context_window | ||
| ) |
GQA causal attention forward (exact version using stdlib expf)
test_attention.py::TestAttentionForward::test_gqa_exact
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa
Uses standard library expf for numerical accuracy reference. Used by BF16 wrapper to avoid approximation error accumulation.
After changes: make test
Definition at line 294 of file attention_kernels.c.
References causal_softmax_head_major_exact(), qkv_index(), and score_index().
Referenced by attention_forward_causal_head_major_gqa_bf16().
| void attention_forward_causal_head_major_gqa_flash | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
Flash attention forward for GQA (prefill, no score materialization)
test_flash_attention.py::TestFlashAttention::test_flash_forward
test_flash_attention.py::TestFlashAttention::test_flash_vs_score_matrix
test_flash_attention.py::TestFlashAttention::test_flash_gqa
test_attention.py::TestAttentionForward::test_flash_forward
Online softmax with streaming KV. O(N) memory instead of O(N^2). For prefill: all tokens attend to previous tokens.
After changes: make test && make llamacpp-parity-full
Definition at line 800 of file attention_kernels.c.
References FLASH_QUERY_IMPL, and qkv_index().
Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().
| void attention_forward_causal_head_major_gqa_flash_strided | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens | ||
| ) |
Flash attention forward with custom KV stride (for KV cache)
test_flash_attention.py::TestFlashAttention::test_flash_strided
test_kv_cache_attention.py::TestKVCacheAttention::test_flash_attention
Variant with configurable kv_stride_tokens for KV cache layouts where K/V may not be contiguous in memory.
After changes: make test
Definition at line 859 of file attention_kernels.c.
References FLASH_QUERY_IMPL, and qkv_index().
Referenced by mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), model_layer_0_prefill(), model_layer_10_prefill(), model_layer_11_prefill(), model_layer_12_prefill(), model_layer_13_prefill(), model_layer_14_prefill(), model_layer_15_prefill(), model_layer_16_prefill(), model_layer_17_prefill(), model_layer_18_prefill(), model_layer_19_prefill(), model_layer_1_prefill(), model_layer_20_prefill(), model_layer_21_prefill(), model_layer_22_prefill(), model_layer_23_prefill(), model_layer_2_prefill(), model_layer_3_prefill(), model_layer_4_prefill(), model_layer_5_prefill(), model_layer_6_prefill(), model_layer_7_prefill(), model_layer_8_prefill(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().
| void attention_forward_causal_head_major_gqa_flash_strided_sliding | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens, | ||
| int | sliding_window | ||
| ) |
Flash attention forward with sliding window (prefill)
Sliding-window attention for prefill: each token attends to the last W tokens. When sliding_window <= 0, behaves like regular causal attention.
After changes: make test
Definition at line 1316 of file attention_kernels.c.
References qkv_index(), and SLIDING_FLASH_IMPL.
| void attention_forward_decode_head_major_gqa_flash | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
Flash attention decode (single token attends to KV cache)
test_flash_attention.py::TestFlashAttention::test_flash_decode
test_kv_cache_attention.py::TestKVCacheAttention::test_flash_decode
test_fused_attention_decode.py::TestFusedAttentionDecode::test_flash_decode
test_attention.py::TestAttentionForward::test_flash_decode
Single query token attends to kv_tokens in KV cache. Uses true flash attention from attention_flash_true.c.
After changes: make test && make llamacpp-parity-full
Definition at line 1467 of file attention_kernels.c.
References attention_flash_decode().
Referenced by model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
| void attention_forward_decode_head_major_gqa_flash_sliding | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | sliding_window | ||
| ) |
Flash attention decode with sliding window
Single query token attends to the last W tokens in the KV cache. For decode: effective_kv_tokens = min(kv_tokens, sliding_window)
After changes: make test
Definition at line 1382 of file attention_kernels.c.
References SLIDING_DECODE_IMPL.
| void attention_forward_decode_head_major_gqa_regular | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
WARNING: This is NOT true flash attention!
This function was historically named "flash" but implements regular attention with O(n) complexity. It's kept for reference and as a fallback.
TRUE flash attention is implemented in attention_flash_true.c
test_kv_cache_attention.py::TestKVCacheAttention::test_regular_decode
test_attention.py::TestAttentionForward::test_regular_decode
Regular attention decode (score-matrix version) for fallback.
After changes: make test
Definition at line 1524 of file attention_kernels.c.
References FLASH_QUERY_IMPL_DECODE.
Referenced by ck_attention_flash_decode_wrapper(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
| void axpy_2d_f32 | ( | float * | Y, |
| const float * | X, | ||
| float | alpha, | ||
| int | num_tokens, | ||
| int | dim, | ||
| int | y_stride, | ||
| int | x_stride | ||
| ) |
Batched AXPY for 2D tensors: Y[t,:] += alpha * X[t,:].
| Y | Output tensor [num_tokens, dim] |
| X | Input tensor [num_tokens, dim] |
| alpha | Scalar multiplier |
| num_tokens | Number of tokens |
| dim | Hidden dimension |
| y_stride | Stride between Y rows (for alignment) |
| x_stride | Stride between X rows |
Definition at line 221 of file axpy_kernels.c.
References axpy_f32().
| void axpy_f32 | ( | float * | y, |
| const float * | x, | ||
| float | alpha, | ||
| int | n | ||
| ) |
In-place AXPY: y += alpha * x.
test_axpy.py::TestAXPY::test_axpy_f32
test_axpy.py::TestAXPY::test_axpy_vs_naive
In-place scaled vector addition: y += alpha * x. BLAS-like axpy operation.
After changes: make test
Definition at line 54 of file axpy_kernels.c.
Referenced by axpy_2d_f32(), axpy_zero_f32(), moe_accumulate_expert_f32(), and weighted_sum_f32().
| void axpy_zero_f32 | ( | float * | y, |
| const float * | x, | ||
| float | alpha, | ||
| int | n | ||
| ) |
Zero output then accumulate: y = 0; y += alpha * x.
| y | Output vector [n], zeroed then accumulated |
| x | Input vector [n] |
| alpha | Scalar multiplier |
| n | Vector length |
Definition at line 188 of file axpy_kernels.c.
References axpy_f32().
| void backward_causal_softmax_head_major | ( | float * | d_scores, |
| const float * | weights, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | aligned_context_window | ||
| ) |
Definition at line 382 of file softmax_kernels.c.
Referenced by backward_causal_softmax_head_major_bf16().
| void backward_causal_softmax_head_major_bf16 | ( | uint16_t * | d_scores, |
| const uint16_t * | weights, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | aligned_context_window, | ||
| float * | scratch_d_scores, | ||
| float * | scratch_weights | ||
| ) |
Definition at line 53 of file softmax_kernels_bf16.c.
References backward_causal_softmax_head_major(), bf16_tensor_to_float(), and float_tensor_to_bf16().
| void causal_softmax_head_major | ( | float * | scores, |
| int | num_heads, | ||
| int | num_tokens, | ||
| int | aligned_context_window | ||
| ) |
Causal softmax (in-place, row-wise)
test_softmax.py::TestSoftmaxForward::test_causal_softmax
test_softmax.py::TestSoftmaxForward::test_causal_vs_softmax
test_attention.py::TestAttentionForward::test_softmax_correctness
Applies causal mask (j > i => 0) and softmax to scores matrix. In-place on [num_heads, T, T] scores matrix.
After changes: make test && make llamacpp-parity-full
Definition at line 144 of file softmax_kernels.c.
Referenced by attention_forward_causal_head_major(), attention_forward_causal_head_major_gqa(), and causal_softmax_head_major_bf16().
| void causal_softmax_head_major_bf16 | ( | uint16_t * | scores, |
| int | num_heads, | ||
| int | num_tokens, | ||
| int | aligned_context_window, | ||
| float * | scratch | ||
| ) |
Definition at line 31 of file softmax_kernels_bf16.c.
References bf16_tensor_to_float(), causal_softmax_head_major(), and float_tensor_to_bf16().
| void causal_softmax_head_major_exact | ( | float * | scores, |
| int | num_heads, | ||
| int | num_tokens, | ||
| int | aligned_context_window | ||
| ) |
Causal softmax (exact version using stdlib expf)
test_softmax.py::TestSoftmaxForward::test_causal_softmax_exact
test_softmax.py::TestSoftmaxForward::test_exact_vs_fast
Exact causal softmax using standard library expf for numerical accuracy reference.
After changes: make test
Definition at line 339 of file softmax_kernels.c.
Referenced by attention_forward_causal_head_major_exact(), and attention_forward_causal_head_major_gqa_exact().
| void ck_attention_flash_decode_wrapper | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
Wrapper to call TRUE flash attention from orchestration layer.
| q_token | Query token [H, D_h] |
| k_cache | Cached keys [T_k, H, D_h] |
| v_cache | Cached values [T_k, H, D_h] |
| out_token | Output [H, D_h] |
| num_heads | Number of heads |
| num_kv_heads | Number of KV heads (for GQA) |
| kv_tokens | Number of tokens in KV cache |
| cache_capacity | Cache capacity |
| head_dim | Head dimension |
| aligned_head_dim | Aligned head dimension |
Definition at line 72 of file ckernel_orchestration.c.
References attention_flash_decode(), and attention_forward_decode_head_major_gqa_regular().
Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), and ck_layer_forward_rmsnorm_swiglu_decode_quant().
| int ck_flash_attn_choose_tile_k | ( | int | D_h | ) |
| int ck_flash_attn_fast_exp_kind | ( | void | ) |
Definition at line 112 of file attention_flash_true.c.
| void ck_gemm_nt_head_major_q5_0 | ( | const float * | attn_out, |
| const void * | wo, | ||
| const float * | bias, | ||
| float * | output, | ||
| int | tokens, | ||
| int | embed_dim, | ||
| int | num_heads, | ||
| int | head_dim | ||
| ) |
Output projection from head-major attention (auto-dispatch)
This replaces flatten_head_major() + ck_gemm_nt_quant() with a single strided-access kernel that reads head-major attention output directly.
Definition at line 328 of file gemm_head_major_output.c.
References gemv_nt_q5_0_head_major_output().
Referenced by mega_fused_attention_prefill().
| void ck_gemm_nt_head_major_q8_0 | ( | const float * | attn_out, |
| const void * | wo, | ||
| const float * | bias, | ||
| float * | output, | ||
| int | tokens, | ||
| int | embed_dim, | ||
| int | num_heads, | ||
| int | head_dim | ||
| ) |
Output projection from head-major attention (Q8_0 weights)
Definition at line 353 of file gemm_head_major_output.c.
References CK_FP16_TO_FP32, block_q8_0::d, QK8_0, and block_q8_0::qs.
Referenced by mega_fused_attention_prefill().
| int ck_get_num_threads | ( | void | ) |
Definition at line 178 of file ckernel_strict.c.
References ck_set_num_threads(), g_num_threads, and g_threads_initialized.
Referenced by gemm_blocked_serial().
| int ck_get_physical_cores | ( | void | ) |
Definition at line 62 of file ckernel_strict.c.
References CK_ADD_PAIR.
| void ck_set_num_threads | ( | int | num_threads | ) |
Definition at line 148 of file ckernel_strict.c.
References ck_get_physical_cores(), ck_parse_env_int(), g_num_threads, and g_threads_initialized.
Referenced by ck_get_num_threads().
| void ck_set_strict_parity | ( | int | enabled | ) |
| int ck_strict_parity_enabled | ( | void | ) |
Definition at line 33 of file ckernel_strict.c.
References ck_strict_parity.
Referenced by ck_q8k_activations_enabled(), gemm_avx512_parallel(), gemm_blocked_serial(), gemm_fine_grained_parallel(), gemm_naive_parallel(), gemm_nn_avx512(), gemm_nn_blocked(), gemm_nn_parallel(), gemm_tn_avx512(), gemm_tn_blocked(), and gemm_tn_parallel().
| CKMathBackend ckernel_backend_native | ( | void | ) |
Obtain the built-in native backend (single-node CPU, C + intrinsics).
Definition at line 39 of file backend_native.c.
References ckernel_sgemm_native(), and CKMathBackend::sgemm.
| void dequant_q4_0_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q4_0 row (multiple blocks)
| src | Q4_0 data |
| dst | FP32 output |
| n_elements | Number of elements to dequantize |
Definition at line 61 of file dequant_kernels.c.
References dequant_q4_0_block(), and QK4_0.
| void dequant_q4_1_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q4_1 row (multiple blocks)
Definition at line 139 of file dequant_kernels.c.
References dequant_q4_1_block(), and QK4_1.
Referenced by dequant_row().
| void dequant_q4_k_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q4_K row (multiple blocks)
Definition at line 370 of file dequant_kernels.c.
References dequant_q4_k_block(), and QK_K.
Referenced by embedding_forward_q4_k().
| void dequant_q5_0_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q5_0 row (multiple blocks)
Definition at line 196 of file dequant_kernels.c.
References dequant_q5_0_block(), and QK5_0.
| void dequant_q5_1_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q5_1 row (multiple blocks)
Definition at line 255 of file dequant_kernels.c.
References dequant_q5_1_block(), and QK5_1.
| void dequant_q6_k_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q6_K row (multiple blocks)
Definition at line 420 of file dequant_kernels.c.
References dequant_q6_k_block(), and QK_K.
Referenced by embedding_forward_q6_k().
| void dequant_q8_0_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q8_0 row (multiple blocks)
Definition at line 286 of file dequant_kernels.c.
References dequant_q8_0_block(), and QK8_0.
Referenced by dequant_row(), and embedding_forward_q8_0().
| void embedding_backward | ( | const int32_t * | token_ids, |
| int | token_count, | ||
| const float * | d_output, | ||
| float * | d_token_embeddings, | ||
| float * | d_pos_embeddings, | ||
| int | vocab_size, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | context_window, | ||
| int | add_pos | ||
| ) |
Definition at line 241 of file embedding_kernels.c.
References vocab_size.
| void embedding_backward_bf16 | ( | const int32_t * | token_ids, |
| int | token_count, | ||
| const uint16_t * | d_output, | ||
| uint16_t * | d_token_embeddings, | ||
| uint16_t * | d_pos_embeddings, | ||
| int | vocab_size, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | context_window, | ||
| int | add_pos | ||
| ) |
Definition at line 72 of file embedding_kernels_bf16.c.
References bf16_to_float(), float_to_bf16(), and vocab_size.
| void embedding_forward | ( | const int32_t * | token_ids, |
| int | token_count, | ||
| int | vocab_size, | ||
| const float * | token_embeddings, | ||
| const float * | pos_embeddings, | ||
| float * | output, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | context_window, | ||
| int | add_pos | ||
| ) |
Definition at line 22 of file embedding_kernels.c.
References vocab_size.
| void embedding_forward_bf16 | ( | const int32_t * | token_ids, |
| int | token_count, | ||
| int | vocab_size, | ||
| const uint16_t * | token_embeddings, | ||
| const uint16_t * | pos_embeddings, | ||
| uint16_t * | output, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | context_window, | ||
| int | add_pos | ||
| ) |
Definition at line 21 of file embedding_kernels_bf16.c.
References bf16_to_float(), float_to_bf16(), and vocab_size.
| void embedding_forward_q4_k | ( | const int32_t * | token_ids, |
| int | token_count, | ||
| int | vocab_size, | ||
| const void * | token_embeddings, | ||
| const float * | pos_embeddings, | ||
| float * | output, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | context_window, | ||
| int | add_pos | ||
| ) |
Definition at line 76 of file embedding_kernels.c.
References CK_DT_Q4_K, ck_dtype_row_bytes(), dequant_q4_k_row(), and vocab_size.
Referenced by model_decode_token(), model_forward_prefill_impl(), qwen2_0_5b_decode_decode_token(), and qwen2_0_5b_decode_forward_prefill_impl().
| void embedding_forward_q6_k | ( | const int32_t * | token_ids, |
| int | token_count, | ||
| int | vocab_size, | ||
| const void * | token_embeddings, | ||
| const float * | pos_embeddings, | ||
| float * | output, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | context_window, | ||
| int | add_pos | ||
| ) |
Definition at line 186 of file embedding_kernels.c.
References CK_DT_Q6_K, ck_dtype_row_bytes(), dequant_q6_k_row(), and vocab_size.
| void embedding_forward_q8_0 | ( | const int32_t * | token_ids, |
| int | token_count, | ||
| int | vocab_size, | ||
| const void * | token_embeddings, | ||
| const float * | pos_embeddings, | ||
| float * | output, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | context_window, | ||
| int | add_pos | ||
| ) |
Definition at line 131 of file embedding_kernels.c.
References CK_DT_Q8_0, ck_dtype_row_bytes(), dequant_q8_0_row(), and vocab_size.
Referenced by qwen2_0_5b_decode_decode_token(), and qwen2_0_5b_decode_forward_prefill_impl().
| void fc1_backward_kernel | ( | const float * | d_output, |
| const float * | fc1_input, | ||
| const float * | W_fc1, | ||
| float * | d_input, | ||
| float * | d_W_fc1, | ||
| float * | d_b_fc1, | ||
| int | T, | ||
| int | aligned_in, | ||
| int | aligned_out, | ||
| int | num_threads | ||
| ) |
Definition at line 167 of file mlp_kernels.c.
References gemm_nn_avx512(), and gemm_tn_avx512().
Referenced by ck_layer_backward_rmsnorm_swiglu().
| void fc2_backward_kernel | ( | const float * | d_output, |
| const float * | fc2_input, | ||
| const float * | W_fc2, | ||
| float * | d_input, | ||
| float * | d_W_fc2, | ||
| float * | d_b_fc2, | ||
| int | T, | ||
| int | aligned_in, | ||
| int | aligned_out, | ||
| int | num_threads | ||
| ) |
Definition at line 118 of file mlp_kernels.c.
References gemm_nn_avx512(), and gemm_tn_avx512().
Referenced by ck_attention_project_head_major_backward(), ck_layer_backward_rmsnorm_swiglu(), and ck_qkv_project_head_major_backward().
| void fused_mlp_swiglu_decode | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| const float * | b_gate, | ||
| const float * | b_up, | ||
| const float * | b_down, | ||
| float * | output, | ||
| int | D, | ||
| int | Hff | ||
| ) |
Definition at line 154 of file mlp_fused_decode.c.
References __attribute__(), MLP_TILE_SIZE, and silu_scalar().
| void fused_mlp_swiglu_decode_tiled | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| const float * | b_gate, | ||
| const float * | b_up, | ||
| const float * | b_down, | ||
| float * | output, | ||
| int | D, | ||
| int | Hff | ||
| ) |
Definition at line 429 of file mlp_fused_decode.c.
References __attribute__(), and silu_scalar().
Referenced by fused_mlp_swiglu_decode_v2().
| void fused_mlp_swiglu_decode_v2 | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| const float * | b_gate, | ||
| const float * | b_up, | ||
| const float * | b_down, | ||
| float * | output, | ||
| int | D, | ||
| int | Hff | ||
| ) |
Definition at line 318 of file mlp_fused_decode.c.
References __attribute__(), fused_mlp_swiglu_decode_tiled(), MAX_SWIGLU_STACK, and silu_scalar().
Referenced by ck_mlp_swiglu_forward_fully_fused_token().
| void fused_mlp_swiglu_prefill | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| float * | output, | ||
| int | seq_len, | ||
| int | hidden, | ||
| int | intermediate, | ||
| float * | scratch | ||
| ) |
Fused MLP (Gate + Up + SwiGLU + Down) for prefill.
Tiles along token dimension to keep gate/up/hidden in L3 cache.
| scratch | Temporary buffer from fused_mlp_swiglu_scratch_size() |
Definition at line 879 of file prefill_fused_gemm.c.
References fused_mlp_swiglu_prefill_bias().
| void fused_mlp_swiglu_prefill_bias | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | W_down, | ||
| const float * | B_gate, | ||
| const float * | B_up, | ||
| const float * | B_down, | ||
| float * | output, | ||
| int | seq_len, | ||
| int | hidden, | ||
| int | intermediate, | ||
| float * | scratch | ||
| ) |
Fused MLP (Gate + Up + SwiGLU + Down) for prefill with biases.
Definition at line 746 of file prefill_fused_gemm.c.
References add_bias_tile(), gemm_tile_nt_strided(), PREFILL_TILE_M, and silu().
Referenced by fused_mlp_swiglu_prefill().
| void fused_mlp_swiglu_prefill_w1w2_quant | ( | const float * | x, |
| const void * | W1, | ||
| const float * | B1, | ||
| CKDataType | w1_dt, | ||
| const void * | W2, | ||
| const float * | B2, | ||
| CKDataType | w2_dt, | ||
| float * | output, | ||
| int | seq_len, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | intermediate_dim, | ||
| int | aligned_intermediate_dim, | ||
| void * | scratch | ||
| ) |
Quantized fused MLP for prefill (W1=gate+up, W2=down)
Uses Q8_0 activations for W1 (with Q5_0/Q8_0 weights) and Q8_K activations for W2 (with Q4_K/Q6_K weights).
Definition at line 965 of file prefill_fused_gemm.c.
References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), gemm_nt_q8_0_mlp_dispatch(), gemm_nt_q8_k_mlp_dispatch(), mlp_q8_0_dtype_supported(), mlp_q8_k_dtype_supported(), PREFILL_TILE_M, quantize_row_q8_0(), quantize_row_q8_k(), and silu_prefill().
Referenced by mega_fused_outproj_mlp_prefill().
| size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size | ( | int | aligned_embed_dim, |
| int | aligned_intermediate_dim | ||
| ) |
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.
Definition at line 1063 of file prefill_fused_gemm.c.
References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), and PREFILL_TILE_M.
Referenced by mega_fused_outproj_mlp_prefill_scratch_size().
| size_t fused_mlp_swiglu_scratch_size | ( | int | intermediate | ) |
Get scratch buffer size for fused_mlp_swiglu_prefill.
Definition at line 899 of file prefill_fused_gemm.c.
References PREFILL_TILE_M.
| void fused_rmsnorm_qkv_prefill | ( | const float * | x, |
| const float * | gamma, | ||
| const float * | Wq, | ||
| const float * | Wk, | ||
| const float * | Wv, | ||
| float * | Q, | ||
| float * | K, | ||
| float * | V, | ||
| int | seq_len, | ||
| int | hidden, | ||
| int | q_dim, | ||
| int | kv_dim, | ||
| float | eps, | ||
| float * | scratch | ||
| ) |
Fused RMSNorm + QKV projection for prefill.
Tiles along token dimension to keep intermediate x_norm in L2 cache. Avoids ~7MB DRAM traffic per layer for seq_len=1024, hidden=896.
| scratch | Temporary buffer from fused_rmsnorm_qkv_scratch_size() |
Fused RMSNorm + QKV projection for prefill.
KEY INSIGHT: For Qwen2-0.5B, all QKV weights fit in L3: Wq (896×896) + Wk (128×896) + Wv (128×896) = 4.1MB < 6MB L3
So we use M-tiling (tokens) only:
This avoids both:
Definition at line 393 of file prefill_fused_gemm.c.
References gemm_tile_nt_strided(), PREFILL_TILE_M, and rmsnorm_tile().
| void fused_rmsnorm_qkv_prefill_head_major | ( | const float * | x, |
| const float * | gamma, | ||
| const float * | Wq, | ||
| const float * | Bq, | ||
| const float * | Wk, | ||
| const float * | Bk, | ||
| const float * | Wv, | ||
| const float * | Bv, | ||
| float * | Q, | ||
| float * | K, | ||
| float * | V, | ||
| int | seq_len, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens, | ||
| float | eps, | ||
| float * | scratch | ||
| ) |
Fused RMSNorm + QKV projection for prefill (head-major outputs)
Writes Q as [num_heads, seq_len, aligned_head_dim] and K/V with stride kv_stride_tokens for KV-cache compatibility.
Q is written as [num_heads, seq_len, aligned_head_dim]. K/V are written with kv_stride_tokens for KV-cache compatibility.
Definition at line 441 of file prefill_fused_gemm.c.
References add_bias_tile(), gemm_tile_nt_strided(), PREFILL_TILE_M, and rmsnorm_tile().
Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().
| void fused_rmsnorm_qkv_prefill_head_major_quant | ( | const float * | x, |
| const float * | gamma, | ||
| const void * | Wq, | ||
| const float * | Bq, | ||
| CKDataType | wq_dt, | ||
| const void * | Wk, | ||
| const float * | Bk, | ||
| CKDataType | wk_dt, | ||
| const void * | Wv, | ||
| const float * | Bv, | ||
| CKDataType | wv_dt, | ||
| float * | Q, | ||
| float * | K, | ||
| float * | V, | ||
| int | seq_len, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
Supports Q5_0 or Q8_0 weights with Q8_0 activations.
Supports Q5_0 or Q8_0 weights with Q8_0 activations. Writes K/V directly into KV cache layout (kv_stride_tokens).
Definition at line 519 of file prefill_fused_gemm.c.
References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), gemm_nt_q8_0_dispatch(), gemm_nt_q8_k_qkv_dispatch(), PREFILL_TILE_M, qkv_q8_0_dtype_supported(), qkv_q8_k_dtype_supported(), quantize_row_q8_0(), quantize_row_q8_k(), and rmsnorm_tile().
Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().
| size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size | ( | int | aligned_embed_dim | ) |
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
Definition at line 651 of file prefill_fused_gemm.c.
References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), and PREFILL_TILE_M.
Referenced by mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), mega_fused_attention_prefill_q8_0_scratch_size(), and mega_fused_attention_prefill_scratch_size().
| size_t fused_rmsnorm_qkv_scratch_size | ( | int | hidden | ) |
Get scratch buffer size for fused_rmsnorm_qkv_prefill.
Definition at line 739 of file prefill_fused_gemm.c.
References PREFILL_TILE_M.
| void geglu_backward_fp32 | ( | const float * | x, |
| const float * | d_out, | ||
| float * | d_x, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
GeGLU backward pass (fp32)
Computes dL/dx given dL/d(out), where out = GELU(a) * b. By the chain rule: dL/da = dL/d(out) * dGELU(a)/da * b, and dL/db = dL/d(out) * GELU(a).
After changes: make test
Definition at line 843 of file gelu_kernels.c.
| void geglu_forward_bf16 | ( | const uint16_t * | x, |
| uint16_t * | out, | ||
| int | tokens, | ||
| int | dim, | ||
| float * | scratch | ||
| ) |
GeGLU forward pass (bf16)
BF16 version: converts to FP32, computes, converts back. Caller provides scratch buffer of size 3 * tokens * dim * sizeof(float).
Layout:
Note: We need separate buffers for input and output to avoid overlap when tokens > 1. The input is 2*dim per token, output is dim per token.
After changes: make test
Definition at line 813 of file gelu_kernels.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and geglu_forward_fp32().
| void geglu_forward_fp32 | ( | const float * | x, |
| float * | out, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
GeGLU forward pass (fp32)
Computes out = GELU(a) * b where x = [a, b] along last dimension. Input shape: [tokens, 2 * dim], Output shape: [tokens, dim]
After changes: make test
Definition at line 623 of file gelu_kernels.c.
References __attribute__().
| void gelu_backward_exact | ( | const float * | input, |
| const float * | d_output, | ||
| float * | d_input, | ||
| size_t | n | ||
| ) |
| void gelu_backward_exact_bf16 | ( | const uint16_t * | input, |
| const uint16_t * | d_output, | ||
| uint16_t * | d_input, | ||
| size_t | n, | ||
| float * | scratch_input, | ||
| float * | scratch_d_output, | ||
| float * | scratch_d_input | ||
| ) |
Definition at line 46 of file gelu_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and gelu_backward_scalar().
| void gelu_backward_fast | ( | const float * | input, |
| const float * | d_output, | ||
| float * | d_input, | ||
| size_t | n | ||
| ) |
Definition at line 486 of file gelu_kernels.c.
References __attribute__().
Referenced by gelu_backward_fast_bf16().
| void gelu_backward_fast_bf16 | ( | const uint16_t * | input, |
| const uint16_t * | d_output, | ||
| uint16_t * | d_input, | ||
| size_t | n, | ||
| float * | scratch_input, | ||
| float * | scratch_d_output, | ||
| float * | scratch_d_input | ||
| ) |
Definition at line 69 of file gelu_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and gelu_backward_fast().
| void gelu_backward_scalar | ( | const float * | input, |
| const float * | d_output, | ||
| float * | d_input, | ||
| size_t | n | ||
| ) |
| void gelu_exact_inplace | ( | float * | data, |
| size_t | n | ||
| ) |
Definition at line 446 of file gelu_kernels.c.
Referenced by gelu_fast_inplace_bf16(), and mlp_token_parallel_exact().
| void gelu_fast_inplace | ( | float * | data, |
| size_t | n | ||
| ) |
GELU activation forward (fast approximation, in-place)
test_gelu.py::TestGELUForward::test_gelu_fast_inplace
test_gelu.py::TestGELUForward::test_gelu_vs_exact
test_parity.py::test_gelu_parity
Fast GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) In-place on contiguous buffer.
After changes: make test && make llamacpp-parity-full
Definition at line 132 of file gelu_kernels.c.
References __attribute__().
Referenced by mlp_token_parallel().
| void gelu_fast_inplace_bf16 | ( | uint16_t * | data, |
| size_t | n, | ||
| float * | scratch | ||
| ) |
Definition at line 31 of file gelu_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and gelu_exact_inplace().
| void gemm_avx512_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 149 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), gemm_naive_parallel(), and gemm_naive_serial_double().
| void gemm_bias_gelu_fused | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 131 of file gemm_fused_kernels.c.
References C, fast_gelu_scalar(), and hsum256_ps_fused().
| void gemm_bias_relu_fused | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
| void gemm_bias_silu_fused | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
| void gemm_blocked_serial | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 661 of file gemm_kernels.c.
References C, ck_gemm_add_bias(), ck_get_num_threads(), ck_min(), ck_strict_parity_enabled(), gemm_microkernel(), gemm_naive_serial_double(), and gemm_nt_matvec_parallel().
Referenced by ck_attention_project_head_major(), ck_gemm_nt_quant(), ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_fused_token(), ck_qkv_project_head_major(), ck_qkv_project_head_major_token(), mlp_token_parallel(), and mlp_token_parallel_exact().
| void gemm_blocked_serial_bf16 | ( | const uint16_t * | A, |
| const uint16_t * | B, | ||
| const uint16_t * | bias, | ||
| uint16_t * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
| void gemm_fine_grained_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 205 of file gemm_kernels.c.
References C, ck_min(), ck_strict_parity_enabled(), gemm_naive_parallel(), and gemm_naive_serial_double().
| void gemm_microkernel | ( | const float * | A, |
| const float * | B, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K, | ||
| int | B_transposed | ||
| ) |
Definition at line 1134 of file gemm_microkernel.c.
References C, gemm_microkernel_blocked(), gemm_microkernel_blocked_bt(), gemm_microkernel_packed(), and PACK_THRESHOLD.
Referenced by gemm_blocked_serial().
| void gemm_microkernel_blocked | ( | const float * | A, |
| const float * | B, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 934 of file gemm_microkernel.c.
References C, gemm_init_threads(), gemm_microkernel_edge(), gemm_microkernel_sequential(), KC, MR, and NR.
Referenced by gemm_microkernel(), and gemm_microkernel_packed().
| void gemm_microkernel_blocked_bt | ( | const float * | A, |
| const float * | B, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 1058 of file gemm_microkernel.c.
References C, KC, MC, MR, NC, and NR.
Referenced by gemm_microkernel().
| void gemm_microkernel_packed | ( | const float * | A, |
| const float * | B, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 840 of file gemm_microkernel.c.
References C, and gemm_microkernel_blocked().
Referenced by gemm_microkernel().
| void gemm_naive_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 125 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), and gemm_naive_serial_double().
Referenced by ck_attention_project_head_major_ref(), ck_mlp_swiglu_forward_ref(), ck_qkv_project_head_major_ref(), gemm_avx512_parallel(), and gemm_fine_grained_parallel().
| void gemm_nn_avx512 | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 339 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), gemm_nn_parallel(), and gemm_nn_serial_double().
Referenced by fc1_backward_kernel(), and fc2_backward_kernel().
| void gemm_nn_blocked | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 402 of file gemm_kernels.c.
References C, ck_min(), ck_strict_parity_enabled(), and gemm_nn_serial_double().
| void gemm_nn_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 317 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), and gemm_nn_serial_double().
Referenced by gemm_nn_avx512().
| void gemm_nt_q4_0 | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
| A | Input matrix [M x K], row-major FP32 |
| B | Weight matrix in Q4_0 format, [N x K] stored row-major |
| bias | Optional bias [N], NULL if not used |
| C | Output [M x N], row-major FP32 |
| M | Batch size (number of tokens) |
| N | Output dimension (number of rows in B) |
| K | Input dimension |
Definition at line 176 of file gemm_kernels_q4_0.c.
References C, CK_FP16_TO_FP32, block_q4_0::d, QK4_0, and block_q4_0::qs.
Referenced by ck_gemm_nt_quant().
| void gemm_nt_q4_1 | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
GEMM with transposed Q4_1 weights: C = A @ B^T.
| A | Input activations [M x K], row-major FP32 |
| B | Weight matrix in Q4_1 format [N x K], row-major quantized |
| bias | Optional bias [N], NULL if not used |
| C | Output [M x N], row-major FP32 |
| M | Batch size (number of tokens) |
| N | Output dimension |
| K | Input dimension |
Definition at line 256 of file gemm_kernels_q4_1.c.
References C, CK_FP16_TO_FP32, block_q4_1::d, block_q4_1::m, QK4_1, and block_q4_1::qs.
Referenced by ck_gemm_nt_quant().
| void gemm_nt_q4_k | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 683 of file gemm_kernels_q4k.c.
References C, and gemm_q4_k().
Referenced by ck_attention_project_head_major_q4_k(), ck_gemm_nt_quant(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k(), ck_qkv_project_head_major_q4_k(), ck_qkv_project_head_major_token_q4_k(), model_decode_token(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), 
qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().
| void gemm_nt_q4_k_q8_k | ( | const void * | A_q8, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 295 of file gemm_kernels_q4k_q8k.c.
References C, and gemm_q4_k_q8_k().
Referenced by ck_attention_project_head_major_q4_k_q8_k(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_qkv_project_head_major_token_q4_k_q8_k(), gemm_nt_q8_k_mlp_dispatch(), gemm_nt_q8_k_qkv_dispatch(), model_forward_prefill_impl(), and qwen2_0_5b_decode_forward_prefill_impl().
| void gemm_nt_q5_0 | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 831 of file gemm_kernels_q5_0.c.
References C, gemm_q5_0(), and gemv_q5_0().
Referenced by ck_gemm_nt_quant(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
| void gemm_nt_q5_1 | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
GEMM with transposed Q5_1 weights: C = A @ B^T.
| A | Input activations [M x K], row-major FP32 |
| B | Weight matrix in Q5_1 format [N x K], row-major quantized |
| bias | Optional bias [N], NULL if not used |
| C | Output [M x N], row-major FP32 |
| M | Batch size (number of tokens) |
| N | Output dimension |
| K | Input dimension |
Definition at line 309 of file gemm_kernels_q5_1.c.
References C, CK_FP16_TO_FP32, block_q5_1::d, block_q5_1::m, block_q5_1::qh, QK5_1, and block_q5_1::qs.
Referenced by ck_gemm_nt_quant().
| void gemm_nt_q5_k | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 218 of file gemm_kernels_q5_k.c.
References C, and gemm_nt_q5_k_ref().
| void gemm_nt_q6_k | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 212 of file gemm_kernels_q6k.c.
References C, and gemm_q6_k().
Referenced by ck_gemm_nt_quant(), gemm_nt_q6_k_ref(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
| void gemm_nt_q6_k_q8_k | ( | const void * | A_q8, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
This is the typical inference pattern:
| A_q8 | Input activations in Q8_K format |
| B | Weight matrix in Q6_K format |
| bias | Optional bias vector [N] |
| C | Output matrix |
| M | Batch size (number of tokens) |
| N | Output dimension |
| K | Input dimension |
Definition at line 1144 of file gemm_kernels_q6k_q8k.c.
References C, and gemm_q6_k_q8_k().
Referenced by gemm_nt_q8_k_mlp_dispatch(), and gemm_nt_q8_k_qkv_dispatch().
| void gemm_nt_q8_0 | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
| A | Input matrix [M x K], row-major FP32 |
| B | Weight matrix in Q8_0 format, [N x K] stored row-major |
| bias | Optional bias [N], NULL if not used |
| C | Output [M x N], row-major FP32 |
| M | Batch size (number of tokens) |
| N | Output dimension (number of rows in B) |
| K | Input dimension |
Definition at line 681 of file gemm_kernels_q8_0.c.
References C, CK_FP16_TO_FP32, block_q8_0::d, gemv_q8_0(), QK8_0, and block_q8_0::qs.
Referenced by ck_gemm_nt_quant(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
| void gemm_nt_q8_0_q8_0 | ( | const void * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
C[m,n] = A[m,K] @ B[n,K]^T + bias[n]
Definition at line 582 of file gemm_batch_int8.c.
References C, and gemm_nt_q8_0_q8_0_ref().
Referenced by gemm_nt_q8_0_dispatch(), and gemm_nt_q8_0_mlp_dispatch().
| void gemm_q4_k | ( | float * | Y, |
| const void * | W, | ||
| const float * | X, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Auto-dispatch GEMM based on available SIMD.
Definition at line 461 of file gemm_kernels_q4k.c.
References gemm_q4_k_ref().
Referenced by gemm_nt_q4_k().
| void gemm_q4_k_q8_k | ( | float * | Y, |
| const void * | W, | ||
| const void * | X_q8, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 277 of file gemm_kernels_q4k_q8k.c.
References gemv_q4_k_q8_k(), and QK_K.
Referenced by gemm_nt_q4_k_q8_k().
| void gemm_q6_k | ( | float * | Y, |
| const void * | W, | ||
| const float * | X, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 195 of file gemm_kernels_q6k.c.
References gemv_q6_k().
Referenced by gemm_nt_q6_k().
| void gemm_q6_k_q8_k | ( | float * | Y, |
| const void * | W, | ||
| const void * | X_q8, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.
| Y | Output matrix [N x M] in row-major |
| W | Weight matrix in Q6_K format [M x K] |
| X_q8 | Input matrix in Q8_K format [N x K] |
| M | Number of output rows (output dim) |
| N | Number of input vectors (batch size) |
| K | Input dimension |
Definition at line 1110 of file gemm_kernels_q6k_q8k.c.
References gemv_q6_k_q8_k(), and QK_K.
Referenced by gemm_nt_q6_k_q8_k().
| void gemm_swiglu_fused | ( | const float * | x, |
| const float * | W_gate, | ||
| const float * | W_up, | ||
| const float * | b_gate, | ||
| const float * | b_up, | ||
| float * | output, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 241 of file gemm_fused_kernels.c.
References hsum256_ps_fused().
Referenced by ck_mlp_swiglu_forward_fused_token().
| void gemm_tn_avx512 | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 521 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), gemm_tn_parallel(), and gemm_tn_serial_double().
Referenced by fc1_backward_kernel(), and fc2_backward_kernel().
| void gemm_tn_blocked | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 581 of file gemm_kernels.c.
References C, ck_min(), ck_strict_parity_enabled(), and gemm_tn_serial_double().
| void gemm_tn_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 499 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), and gemm_tn_serial_double().
Referenced by gemm_tn_avx512().
| void gemv_fused_q5_0_bias_dispatch | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| const float * | bias, | ||
| int | M, | ||
| int | K | ||
| ) |
Definition at line 508 of file gemv_fused_quant_bias.c.
References gemv_fused_q5_0_bias().
| void gemv_fused_q8_0_bias_dispatch | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| const float * | bias, | ||
| int | M, | ||
| int | K | ||
| ) |
Definition at line 523 of file gemv_fused_quant_bias.c.
References gemv_fused_q8_0_bias().
| void gemv_q4_0 | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Auto-dispatch GEMV.
Definition at line 132 of file gemm_kernels_q4_0.c.
References gemv_q4_0_ref().
Referenced by dot_q4_0(), and gemm_q4_0().
| void gemv_q4_k | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Auto-dispatch GEMV based on available SIMD.
Definition at line 285 of file gemm_kernels_q4k.c.
References gemv_q4_k_ref().
Referenced by attention_mlp_fused_q4k(), dot_q4_k(), gemm_q4_k_ref(), layer_fused_attn_mlp_qkv_q4k(), and rmsnorm_qkv_q4k_fused().
| void gemv_q4_k_q8_k | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K | ||
| ) |
Definition at line 239 of file gemm_kernels_q4k_q8k.c.
References gemv_q4_k_q8_k_avx(), gemv_q4_k_q8_k_avx2(), gemv_q4_k_q8_k_ref(), gemv_q4_k_q8_k_sse(), and gemv_q4_k_q8_k_vnni().
Referenced by model_decode_token(), model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
| void gemv_q4_k_q8_k_parallel | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K, | ||
| int | ith, | ||
| int | nth | ||
| ) |
Definition at line 206 of file gemm_kernels_q4k_q8k.c.
References dot_q4_k_q8_k_ref(), and QK_K.
| void gemv_q4_k_q8_k_parallel_simd | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K, | ||
| int | ith, | ||
| int | nth | ||
| ) |
Definition at line 263 of file gemm_kernels_q4k_avx.c.
References gemv_q4_k_q8_k_parallel().
Referenced by decode_layer_parallel(), mlp_parallel(), and qkv_projection_parallel().
| void gemv_q4_k_q8_k_ref | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K | ||
| ) |
| void gemv_q5_0 | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
Dispatch priority (best available):
Uses ck_features.h for standardized feature detection.
| y | Output vector [M] |
| W | Weight matrix in Q5_0 format [M x K] |
| x | Input vector [K] |
| M | Number of output rows |
| K | Number of input columns (hidden dimension) |
Definition at line 547 of file gemm_kernels_q5_0.c.
References gemv_q5_0_ref().
| void gemv_q5_0_parallel | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K, | ||
| int | ith, | ||
| int | nth | ||
| ) |
Parallel reference GEMV for Q5_0 × FP32.
Definition at line 576 of file gemm_kernels_q5_0.c.
References CK_FP16_TO_FP32, block_q5_0::d, block_q5_0::qh, QK5_0, and block_q5_0::qs.
Referenced by gemv_q5_0_parallel_simd().
| void gemv_q5_0_parallel_simd | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K, | ||
| int | ith, | ||
| int | nth | ||
| ) |
Parallel SIMD GEMV for Q5_0 × FP32 with prefetching.
Definition at line 622 of file gemm_kernels_q5_0.c.
References gemv_q5_0_parallel(), and QK5_0.
| void gemv_q5_0_q8_0 | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K | ||
| ) |
Matrix-vector multiply with Q5_0 weights and Q8_0 input.
| y | Output vector [M] |
| W | Weight matrix in Q5_0 format [M x K] |
| x_q8 | Input vector in Q8_0 format [K] |
| M | Number of output rows |
| K | Number of columns (must be multiple of 32) |
Definition at line 1529 of file gemm_kernels_q5_0.c.
References QK5_0, and vec_dot_q5_0_q8_0().
| void gemv_q5_1 | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Auto-dispatch GEMV.
Definition at line 184 of file gemm_kernels_q5_1.c.
References gemv_q5_1_ref().
| void gemv_q5_k | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Definition at line 199 of file gemm_kernels_q5_k.c.
References gemv_q5_k_ref().
| void gemv_q6_k | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Definition at line 169 of file gemm_kernels_q6k.c.
References dot_q6_k_ref(), and QK_K.
Referenced by gemm_q6_k().
| void gemv_q6_k_q8_k | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K | ||
| ) |
GEMV: y = W @ x where W is Q6_K and x is Q8_K.
Definition at line 980 of file gemm_kernels_q6k_q8k.c.
References gemv_q6_k_q8_k_avx(), gemv_q6_k_q8_k_avx2(), gemv_q6_k_q8_k_avx512(), gemv_q6_k_q8_k_ref(), and gemv_q6_k_q8_k_sse().
| void gemv_q6_k_q8_k_parallel | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K, | ||
| int | ith, | ||
| int | nth | ||
| ) |
Parallel reference GEMV for Q6_K × Q8_K.
Caller provides ith (thread index) and nth (total threads). Each thread processes rows [r0, r1).
Definition at line 1014 of file gemm_kernels_q6k_q8k.c.
References dot_q6_k_q8_k_ref(), and QK_K.
| void gemv_q6_k_q8_k_parallel_simd | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K, | ||
| int | ith, | ||
| int | nth | ||
| ) |
Parallel SIMD GEMV for Q6_K × Q8_K.
Uses best available SIMD (AVX/SSE) with row prefetching. Caller provides ith/nth from OpenMP region.
Definition at line 1046 of file gemm_kernels_q6k_q8k.c.
References dot_q6_k_q8_k_ref(), and QK_K.
| void gemv_q8_0 | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Auto-dispatch GEMV for Q8_0 weights based on CPU features.
Dispatch priority (best available):
Uses ck_features.h for standardized feature detection.
| y | Output vector [M] |
| W | Weight matrix in Q8_0 format [M x K] |
| x | Input vector [K] |
| M | Number of output rows |
| K | Number of input columns (hidden dimension) |
Definition at line 630 of file gemm_kernels_q8_0.c.
References gemv_q8_0_ref().
| void gemv_q8_0_q8_0 | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K | ||
| ) |
Matrix-vector multiply with Q8_0 weights and Q8_0 input.
| y | Output vector [M] |
| W | Weight matrix in Q8_0 format [M x K] |
| x_q8 | Input vector in Q8_0 format [K] |
| M | Number of output rows |
| K | Number of columns (must be multiple of 32) |
Definition at line 1042 of file gemm_kernels_q8_0.c.
References QK8_0, and vec_dot_q8_0_q8_0().
| void im2patch | ( | const float * | image, |
| float * | patches, | ||
| int | C, | ||
| int | H, | ||
| int | W, | ||
| int | P | ||
| ) |
im2patch: Transforms an image into a sequence of flattened patches.
Image layout: [C, H, W] (row-major: W is the fastest-moving dimension).
Output layout: [num_patches, C * P * P], where num_patches = (H/P) * (W/P) and P = patch_size.
Definition at line 28 of file vision_kernels.c.
References C.
| void im2patch_bf16 | ( | const uint16_t * | image, |
| uint16_t * | patches, | ||
| int | C, | ||
| int | H, | ||
| int | W, | ||
| int | P | ||
| ) |
Definition at line 22 of file vision_kernels_bf16.c.
References C.
| void kv_cache_repack_head_major_inplace | ( | float * | buf, |
| int | num_heads, | ||
| int | tokens, | ||
| int | cache_capacity, | ||
| int | aligned_head_dim | ||
| ) |
Definition at line 28 of file kv_cache_kernels.c.
Referenced by qwen2_0_5b_decode_forward_prefill_impl().
| void kv_cache_store | ( | float *__restrict | kv_cache_k, |
| float *__restrict | kv_cache_v, | ||
| const float *__restrict | k, | ||
| const float *__restrict | v, | ||
| int | layer, | ||
| int | pos, | ||
| int | num_kv_heads, | ||
| int | head_dim, | ||
| int | max_seq_len | ||
| ) |
Definition at line 101 of file kv_cache_kernels.c.
References kv_cache_write_head_major().
| void kv_cache_write_head_major | ( | const float *__restrict | k_token, |
| const float *__restrict | v_token, | ||
| float *__restrict | k_cache, | ||
| float *__restrict | v_cache, | ||
| int | num_kv_heads, | ||
| int | token_index, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim | ||
| ) |
Definition at line 60 of file kv_cache_kernels.c.
Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), kv_cache_store(), mega_fused_attention_decode(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().
| void layernorm_backward_kernel | ( | const float * | d_output, |
| const float * | input, | ||
| const float * | gamma, | ||
| const float * | mean, | ||
| const float * | rstd, | ||
| float * | d_input, | ||
| float * | d_gamma, | ||
| float * | d_beta, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim | ||
| ) |
| void layernorm_backward_kernel_bf16 | ( | const uint16_t * | d_output, |
| const uint16_t * | input, | ||
| const float * | gamma, | ||
| const float * | mean, | ||
| const float * | rstd, | ||
| uint16_t * | d_input, | ||
| float * | d_gamma, | ||
| float * | d_beta, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float * | scratch_d_output, | ||
| float * | scratch_input, | ||
| float * | scratch_d_input | ||
| ) |
Definition at line 84 of file layernorm_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and layernorm_backward_kernel().
| void layernorm_forward_rolled_slice | ( | const float *__restrict | input_slice_base, |
| const float *__restrict | gamma, | ||
| const float *__restrict | beta, | ||
| float *__restrict | output_slice_base, | ||
| float *__restrict | mean_cache_slice, | ||
| float *__restrict | rstd_cache_slice, | ||
| int | num_tokens_in_slice, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps | ||
| ) |
Definition at line 274 of file layernorm_kernels.c.
References layernorm_naive_serial().
Referenced by layernorm_forward_rolled_slice_bf16().
| void layernorm_forward_rolled_slice_bf16 | ( | const uint16_t *__restrict | input_slice_base, |
| const float *__restrict | gamma, | ||
| const float *__restrict | beta, | ||
| uint16_t *__restrict | output_slice_base, | ||
| float *__restrict | mean_cache_slice, | ||
| float *__restrict | rstd_cache_slice, | ||
| int | num_tokens_in_slice, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps, | ||
| float * | scratch_input, | ||
| float * | scratch_output | ||
| ) |
Definition at line 30 of file layernorm_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and layernorm_forward_rolled_slice().
| void layernorm_forward_unrolled_slice | ( | const float *__restrict | input_slice_base, |
| const float *__restrict | gamma, | ||
| const float *__restrict | beta, | ||
| float *__restrict | output_slice_base, | ||
| float *__restrict | mean_cache_slice, | ||
| float *__restrict | rstd_cache_slice, | ||
| int | num_tokens_in_slice, | ||
| int | d_model, | ||
| float | eps | ||
| ) |
Definition at line 598 of file layernorm_kernels.c.
References layernorm_forward_unrolled_slice_scalar().
Referenced by layernorm_forward_unrolled_slice_bf16().
| void layernorm_forward_unrolled_slice_bf16 | ( | const uint16_t *__restrict | input_slice_base, |
| const float *__restrict | gamma, | ||
| const float *__restrict | beta, | ||
| uint16_t *__restrict | output_slice_base, | ||
| float *__restrict | mean_cache_slice, | ||
| float *__restrict | rstd_cache_slice, | ||
| int | num_tokens_in_slice, | ||
| int | d_model, | ||
| float | eps, | ||
| float * | scratch_input, | ||
| float * | scratch_output | ||
| ) |
Definition at line 57 of file layernorm_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and layernorm_forward_unrolled_slice().
| void layernorm_naive_serial | ( | const float * | input, |
| const float * | gamma, | ||
| const float * | beta, | ||
| float * | output, | ||
| float * | mean_cache, | ||
| float * | rstd_cache, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps | ||
| ) |
| void layernorm_naive_serial_matched_precision | ( | const float * | input, |
| const float * | gamma, | ||
| const float * | beta, | ||
| float * | output, | ||
| float * | mean_cache, | ||
| float * | rstd_cache, | ||
| int | tokens, | ||
| int | d_model, | ||
| float | eps | ||
| ) |
Definition at line 624 of file layernorm_kernels.c.
Referenced by layernorm_forward_unrolled_slice_scalar().
| void mlp_token_parallel | ( | const float * | input, |
| const float * | W_fc1, | ||
| const float * | b_fc1, | ||
| const float * | W_fc2, | ||
| const float * | b_fc2, | ||
| float * | fc1_output, | ||
| float * | output, | ||
| int | T, | ||
| int | aligned_dim, | ||
| int | num_threads | ||
| ) |
Definition at line 41 of file mlp_kernels.c.
References gelu_fast_inplace(), and gemm_blocked_serial().
| void mlp_token_parallel_bf16 | ( | const uint16_t * | input, |
| const uint16_t * | W_fc1, | ||
| const uint16_t * | b_fc1, | ||
| const uint16_t * | W_fc2, | ||
| const uint16_t * | b_fc2, | ||
| float * | fc1_output, | ||
| float * | output, | ||
| int | T, | ||
| int | aligned_dim, | ||
| int | num_threads, | ||
| float * | scratch_bias1_f, | ||
| float * | scratch_bias2_f, | ||
| uint16_t * | scratch_fc1_bf16 | ||
| ) |
Optimized MLP Forward (BF16 weights, FP32 activations)
Caller-provided scratch buffers:
  scratch_bias1_f:  [4*D] floats
  scratch_bias2_f:  [D] floats
  scratch_fc1_bf16: [T * 4*D] uint16_t (BF16)
Definition at line 91 of file mlp_kernels_bf16.c.
References bf16_to_float(), float_to_bf16(), gelu_scalar(), and gemm_bf16_fp32out().
| void mlp_token_parallel_bf16_fp32act | ( | const uint16_t * | input, |
| const uint16_t * | W_fc1, | ||
| const uint16_t * | b_fc1, | ||
| const uint16_t * | W_fc2, | ||
| const uint16_t * | b_fc2, | ||
| float * | fc1_output, | ||
| float * | output, | ||
| int | T, | ||
| int | aligned_dim, | ||
| int | num_threads, | ||
| float * | scratch_input_f, | ||
| float * | scratch_bias1_f, | ||
| float * | scratch_bias2_f, | ||
| uint16_t * | scratch_fc1_bf16 | ||
| ) |
Alternative: Fully FP32 activations throughout
Caller-provided scratch buffers:
  scratch_input_f:  [T * D] floats
  scratch_bias1_f:  [4*D] floats
  scratch_bias2_f:  [D] floats
  scratch_fc1_bf16: [T * 4*D] uint16_t (BF16)
Definition at line 186 of file mlp_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), gelu_scalar(), and gemm_bf16_fp32out().
| void mlp_token_parallel_exact | ( | const float * | input, |
| const float * | W_fc1, | ||
| const float * | b_fc1, | ||
| const float * | W_fc2, | ||
| const float * | b_fc2, | ||
| float * | fc1_output, | ||
| float * | output, | ||
| int | T, | ||
| int | aligned_dim, | ||
| int | num_threads | ||
| ) |
Definition at line 76 of file mlp_kernels.c.
References gelu_exact_inplace(), and gemm_blocked_serial().
| void moe_accumulate_expert_f32 | ( | float * | output, |
| const float * | expert_output, | ||
| float | routing_weight, | ||
| int | hidden_dim | ||
| ) |
Accumulate expert output: output += routing_weight * expert_output.
| output | Token output buffer [hidden_dim], accumulated in place |
| expert_output | Expert's output for this token [hidden_dim] |
| routing_weight | Softmax routing weight for this expert |
| hidden_dim | Hidden dimension |
Definition at line 256 of file axpy_kernels.c.
References axpy_f32().
| void patch2im | ( | const float * | d_patches, |
| float * | d_image, | ||
| int | C, | ||
| int | H, | ||
| int | W, | ||
| int | P | ||
| ) |
patch2im: Accumulates gradients from patches back into the image. (Backward pass)
d_patches: [num_patches, C * P * P]
d_image:   [C, H, W] (gradients are accumulated in place)
Definition at line 69 of file vision_kernels.c.
References C.
| void patch2im_bf16 | ( | const uint16_t * | d_patches, |
| uint16_t * | d_image, | ||
| int | C, | ||
| int | H, | ||
| int | W, | ||
| int | P | ||
| ) |
Definition at line 57 of file vision_kernels_bf16.c.
References bf16_to_float(), C, and float_to_bf16().
| void quantize_batch_q8_0 | ( | const float * | x, |
| void * | vy, | ||
| int | num_rows, | ||
| int | k | ||
| ) |
Batch quantize FP32 to Q8_0 format (row-major output)
Quantizes multiple rows of FP32 data to Q8_0 format, placing each row's Q8_0 output at the correct byte offset for GEMM compatibility.
Memory layout:
  Input:  [num_rows, k] FP32, row-major (stride = k * sizeof(float))
  Output: [num_rows, q8_row_bytes] Q8_0, row-major (stride = q8_row_bytes)
where q8_row_bytes = (k / 32) * sizeof(block_q8_0) = (k / 32) * 34
| x | Input FP32 values [num_rows * k] |
| vy | Output Q8_0 blocks [num_rows * (k/32) blocks] |
| num_rows | Number of rows (batch size / tokens) |
| k | Elements per row (must be multiple of 32) |
Definition at line 192 of file gemm_kernels_q8_0.c.
References QK8_0, and quantize_row_q8_0().
| void quantize_batch_q8_k | ( | const float * | x, |
| void * | vy, | ||
| int | num_rows, | ||
| int | k | ||
| ) |
Batch quantize FP32 to Q8_K format (row-major output)
Same as quantize_batch_q8_0 but for Q8_K format (super-blocks).
| x | Input FP32 values [num_rows * k] |
| vy | Output Q8_K blocks |
| num_rows | Number of rows (batch size / tokens) |
| k | Elements per row (must be multiple of 256) |
Definition at line 219 of file gemm_kernels_q8_0.c.
References quantize_row_q8_k().
| void quantize_row_q8_0 | ( | const float * | x, |
| void * | vy, | ||
| int | k | ||
| ) |
Quantize FP32 to Q8_0 format (scalar reference)
| x | Input FP32 values |
| vy | Output Q8_0 blocks |
| k | Number of elements (must be multiple of 32) |
Definition at line 59 of file gemm_kernels_q8_0.c.
References CK_FP32_TO_FP16, block_q8_0::d, id, QK8_0, and block_q8_0::qs.
Referenced by fused_mlp_swiglu_prefill_w1w2_quant(), fused_rmsnorm_qkv_prefill_head_major_quant(), and quantize_attn_out_head_major_q8_0().
| void quantize_row_q8_k | ( | const float * | x, |
| void * | y, | ||
| int | k | ||
| ) |
Definition at line 107 of file gemm_kernels_q4k_q8k.c.
References quantize_row_q8_k_ref(), and quantize_row_q8_k_sse().
Referenced by ck_attention_project_head_major_q4_k_q8_k(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_qkv_project_head_major_q4_k_q8_k(), decode_layer_parallel(), fused_mlp_swiglu_prefill_w1w2_quant(), fused_rmsnorm_qkv_prefill_head_major_quant(), mlp_parallel(), model_decode_token(), model_forward_prefill_impl(), model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_9_decode(), and unfused_rmsnorm_linear_q4k_ref().
| void relu_backward | ( | const float * | input, |
| const float * | d_output, | ||
| float * | d_input, | ||
| size_t | n | ||
| ) |
| void relu_backward_bf16 | ( | const uint16_t * | input, |
| const uint16_t * | d_output, | ||
| uint16_t * | d_input, | ||
| size_t | n | ||
| ) |
| void relu_forward | ( | const float * | input, |
| float * | output, | ||
| size_t | n | ||
| ) |
Definition at line 26 of file relu_kernels.c.
| void relu_forward_bf16 | ( | const uint16_t * | input, |
| uint16_t * | output, | ||
| size_t | n | ||
| ) |
| void relu_forward_inplace | ( | float * | data, |
| size_t | n | ||
| ) |
Definition at line 54 of file relu_kernels.c.
| void relu_forward_inplace_bf16 | ( | uint16_t * | data, |
| size_t | n | ||
| ) |
| void rmsnorm_backward | ( | const float * | d_output, |
| const float * | input, | ||
| const float * | gamma, | ||
| const float * | rstd_cache, | ||
| float * | d_input, | ||
| float * | d_gamma, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim | ||
| ) |
RMSNorm backward pass
test_rmsnorm.py::TestRMSNormBackward::test_backward_tokens
test_rmsnorm.py::TestRMSNormBackward::test_backward_single
test_parity.py::test_rmsnorm_backward_parity
Computes dX and dGamma given dY, X, gamma, and the cached rstd:
  dX_i     = rstd * (dY_i * gamma_i - x_hat_i * m)
  dGamma_i = sum_t (dY_i * x_hat_i)
After changes: make test && make llamacpp-parity-full
Definition at line 184 of file rmsnorm_kernels.c.
Referenced by ck_layer_backward_rmsnorm_swiglu(), rmsnorm_backward_int4(), and rmsnorm_backward_int8().
| void rmsnorm_backward_bf16 | ( | const uint16_t * | d_output, |
| const uint16_t * | input, | ||
| const float * | gamma, | ||
| const float * | rstd_cache, | ||
| uint16_t * | d_input, | ||
| float * | d_gamma, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim | ||
| ) |
Definition at line 113 of file rmsnorm_kernels_bf16.c.
References bf16_to_float(), and float_to_bf16().
| void rmsnorm_backward_int4 | ( | const uint8_t * | d_output, |
| const uint8_t * | input, | ||
| const float * | gamma, | ||
| const float * | rstd_cache, | ||
| uint8_t * | d_input, | ||
| float * | d_gamma, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float * | scratch_d_output, | ||
| float * | scratch_input, | ||
| float * | scratch_d_input | ||
| ) |
Definition at line 104 of file rmsnorm_kernels_int4.c.
References convert_float_to_int4(), convert_int4_to_float(), and rmsnorm_backward().
| void rmsnorm_backward_int8 | ( | const int8_t * | d_output, |
| const int8_t * | input, | ||
| const float * | gamma, | ||
| const float * | rstd_cache, | ||
| int8_t * | d_input, | ||
| float * | d_gamma, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float * | scratch_d_output, | ||
| float * | scratch_input, | ||
| float * | scratch_d_input | ||
| ) |
Definition at line 84 of file rmsnorm_kernels_int8.c.
References convert_float_to_int8(), convert_int8_to_float(), and rmsnorm_backward().
| void rmsnorm_forward | ( | const float * | input, |
| const float * | gamma, | ||
| float * | output, | ||
| float * | rstd_cache, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps | ||
| ) |
RMSNorm forward pass
test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens
test_rmsnorm.py::TestRMSNormForward::test_fp32_single
test_rmsnorm.py::TestRMSNormForward::test_perf_rolled
test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat
test_parity.py::test_rmsnorm_parity
RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
After changes: make test && make llamacpp-parity-full
Definition at line 50 of file rmsnorm_kernels.c.
Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), mega_fused_outproj_mlp_prefill(), model_decode_token(), model_forward_prefill_impl(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), 
qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), qwen2_0_5b_decode_layer_9_prefill(), rmsnorm_forward_int4(), and rmsnorm_forward_int8().
| void rmsnorm_forward_bf16 | ( | const uint16_t * | input, |
| const float * | gamma, | ||
| uint16_t * | output, | ||
| float * | rstd_cache, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps | ||
| ) |
Definition at line 24 of file rmsnorm_kernels_bf16.c.
References bf16_to_float(), and float_to_bf16().
| void rmsnorm_forward_int4 | ( | const uint8_t * | input, |
| const float * | gamma, | ||
| uint8_t * | output, | ||
| float * | rstd_cache, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps, | ||
| float * | scratch_input, | ||
| float * | scratch_output | ||
| ) |
Definition at line 78 of file rmsnorm_kernels_int4.c.
References convert_float_to_int4(), convert_int4_to_float(), and rmsnorm_forward().
| void rmsnorm_forward_int8 | ( | const int8_t * | input, |
| const float * | gamma, | ||
| int8_t * | output, | ||
| float * | rstd_cache, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps, | ||
| float * | scratch_input, | ||
| float * | scratch_output | ||
| ) |
Definition at line 58 of file rmsnorm_kernels_int8.c.
References convert_float_to_int8(), convert_int8_to_float(), and rmsnorm_forward().
| void rope_backward | ( | const float * | d_out, |
| float * | d_x, | ||
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset | ||
| ) |
RoPE backward (inverse rotation)
test_rope.py::TestRoPEBackward::test_rope_backward
test_rope.py::TestRoPEBackward::test_rope_backward_vs_separate
RoPE backward: inverse rotation (rotate by -θ). Since cos(-θ) = cos(θ) and sin(-θ) = -sin(θ): d_x[2i] = d0 * c + d1 * s; d_x[2i+1] = -d0 * s + d1 * c
After changes: make test
Definition at line 238 of file rope_kernels.c.
Referenced by rope_backward_bf16(), and rope_backward_qk().
| void rope_backward_bf16 | ( | const uint16_t * | d_out, |
| uint16_t * | d_x, | ||
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset, | ||
| float * | scratch_d_out, | ||
| float * | scratch_d_x | ||
| ) |
Definition at line 52 of file rope_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and rope_backward().
Referenced by rope_backward_qk_bf16().
| void rope_backward_inplace | ( | float * | d_x, |
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset | ||
| ) |
RoPE backward in-place (overwrite with inverse rotation)
In-place backward: overwrites the incoming gradient buffer (d_x, which holds d_out on entry) with the inverse-rotated gradients. Useful when d_x == d_out is acceptable (saves memory).
After changes: make test
Definition at line 345 of file rope_kernels.c.
| void rope_backward_qk | ( | const float * | d_q_out, |
| const float * | d_k_out, | ||
| float * | d_q, | ||
| float * | d_k, | ||
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset | ||
| ) |
RoPE backward for both dQ and dK
Combined RoPE backward for both dQ and dK gradients.
After changes: make test
Definition at line 497 of file rope_kernels.c.
References rope_backward().
Referenced by ck_layer_backward_rmsnorm_swiglu().
| void rope_backward_qk_bf16 | ( | const uint16_t * | d_q_out, |
| const uint16_t * | d_k_out, | ||
| uint16_t * | d_q, | ||
| uint16_t * | d_k, | ||
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset, | ||
| float * | scratch_dq_out, | ||
| float * | scratch_dq, | ||
| float * | scratch_dk_out, | ||
| float * | scratch_dk | ||
| ) |
Definition at line 103 of file rope_kernels_bf16.c.
References rope_backward_bf16().
| void rope_forward | ( | float * | x, |
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset | ||
| ) |
RoPE forward (head-major layout, in-place)
test_rope.py::TestRoPEForward::test_rope_forward
test_rope.py::TestRoPEForward::test_rope_vs_separate
test_parity.py::test_rope_parity
Applies rotary position embeddings in-place to Q or K tensor. x: [num_heads, num_tokens, head_dim] head-major
After changes: make test && make llamacpp-parity-full
Definition at line 180 of file rope_kernels.c.
References rope_apply_head().
Referenced by model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_9_decode(), rope_forward_bf16(), and rope_forward_qk().
| void rope_forward_bf16 | ( | uint16_t * | x, |
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset, | ||
| float * | scratch | ||
| ) |
Definition at line 28 of file rope_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and rope_forward().
Referenced by rope_forward_qk_bf16().
| void rope_forward_qk | ( | float * | q, |
| float * | k, | ||
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset | ||
| ) |
RoPE forward for both Q and K (common inference pattern)
test_rope.py::TestRoPEForward::test_rope_forward_qk
test_fused_attention_decode.py::TestFusedAttentionDecode::test_qk_rope
test_parity.py::test_rope_qk_parity
Combined RoPE forward for both Q and K in one call. q: [num_heads, num_tokens, head_dim] k: [num_kv_heads, num_tokens, head_dim]
After changes: make test && make llamacpp-parity-full
Definition at line 448 of file rope_kernels.c.
References rope_forward().
Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), 
qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().
| void rope_forward_qk_bf16 | ( | uint16_t * | q, |
| uint16_t * | k, | ||
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset, | ||
| float * | scratch_q, | ||
| float * | scratch_k | ||
| ) |
Definition at line 79 of file rope_kernels_bf16.c.
References rope_forward_bf16().
| void rope_forward_qk_strided | ( | float * | q, |
| float * | k, | ||
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset, | ||
| int | q_stride_tokens, | ||
| int | k_stride_tokens | ||
| ) |
RoPE forward for both Q and K with custom strides (KV cache layouts)
test_rope.py::TestRoPEForward::test_rope_forward_qk_strided
test_kv_cache_attention.py::TestKVCacheAttention::test_qk_rope_strided
Combined QK RoPE with configurable strides for KV cache layouts.
After changes: make test
Definition at line 472 of file rope_kernels.c.
References rope_forward_strided().
Referenced by mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), model_layer_0_prefill(), model_layer_10_prefill(), model_layer_11_prefill(), model_layer_12_prefill(), model_layer_13_prefill(), model_layer_14_prefill(), model_layer_15_prefill(), model_layer_16_prefill(), model_layer_17_prefill(), model_layer_18_prefill(), model_layer_19_prefill(), model_layer_1_prefill(), model_layer_20_prefill(), model_layer_21_prefill(), model_layer_22_prefill(), model_layer_23_prefill(), model_layer_2_prefill(), model_layer_3_prefill(), model_layer_4_prefill(), model_layer_5_prefill(), model_layer_6_prefill(), model_layer_7_prefill(), model_layer_8_prefill(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().
| void rope_forward_strided | ( | float * | x, |
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset, | ||
| int | head_stride_tokens | ||
| ) |
RoPE forward with custom head stride (for KV cache layouts)
test_rope.py::TestRoPEForward::test_rope_strided
test_kv_cache_attention.py::TestKVCacheAttention::test_rope_decode
Variant with configurable head_stride_tokens for non-contiguous head layouts.
After changes: make test
Definition at line 207 of file rope_kernels.c.
References rope_apply_head().
Referenced by rope_forward_qk_strided().
| void rope_precompute_cache | ( | float * | cos_cache, |
| float * | sin_cache, | ||
| int | max_seq_len, | ||
| int | head_dim, | ||
| float | base | ||
| ) |
Precompute RoPE cos/sin cache
test_rope.py::TestRoPECache::test_cache_computation
test_rope.py::TestRoPECache::test_cache_values
Precomputes cos(m * theta_i) and sin(m * theta_i) for positions 0..max_seq_len-1. cos_cache, sin_cache: [max_seq_len, head_dim/2]
After changes: make test
Definition at line 52 of file rope_kernels.c.
| void scal_copy_f32 | ( | float * | y, |
| const float * | x, | ||
| float | alpha, | ||
| int | n | ||
| ) |
Scaled copy: y = alpha * x.
| y | Output vector [n] |
| x | Input vector [n] |
| alpha | Scalar multiplier |
| n | Vector length |
Definition at line 105 of file axpy_kernels.c.
Referenced by weighted_sum_f32().
| void sigmoid_backward | ( | const float * | input, |
| const float * | d_output, | ||
| float * | d_input, | ||
| size_t | n | ||
| ) |
Definition at line 138 of file sigmoid_kernels.c.
References sigmoid_scalar().
Referenced by sigmoid_backward_bf16().
| void sigmoid_backward_bf16 | ( | const uint16_t * | input, |
| const uint16_t * | d_output, | ||
| uint16_t * | d_input, | ||
| size_t | n, | ||
| float * | scratch_input, | ||
| float * | scratch_d_output, | ||
| float * | scratch_d_input | ||
| ) |
Definition at line 45 of file sigmoid_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and sigmoid_backward().
| void sigmoid_forward | ( | const float * | input, |
| float * | output, | ||
| size_t | n | ||
| ) |
Definition at line 122 of file sigmoid_kernels.c.
References sigmoid_scalar().
Referenced by sigmoid_forward_bf16().
| void sigmoid_forward_bf16 | ( | const uint16_t * | input, |
| uint16_t * | output, | ||
| size_t | n, | ||
| float * | scratch_input, | ||
| float * | scratch_output | ||
| ) |
Definition at line 27 of file sigmoid_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), and sigmoid_forward().
| float sigmoid_scalar | ( | float | x | ) |
Definition at line 26 of file sigmoid_kernels.c.
Referenced by sigmoid_backward(), sigmoid_forward(), swiglu_backward(), swiglu_backward_bf16(), swiglu_backward_exact(), swiglu_forward(), swiglu_forward_bf16(), and swiglu_forward_exact().
| void softmax_cross_entropy_loss | ( | const float * | logits, |
| const int32_t * | targets, | ||
| int | tokens, | ||
| int | vocab_size, | ||
| float * | d_logits, | ||
| float * | loss_out | ||
| ) |
Definition at line 21 of file loss_kernels.c.
References vocab_size.
Referenced by softmax_cross_entropy_loss_bf16().
| void softmax_cross_entropy_loss_bf16 | ( | const uint16_t * | logits, |
| const int32_t * | targets, | ||
| int | tokens, | ||
| int | vocab_size, | ||
| uint16_t * | d_logits, | ||
| float * | loss_out, | ||
| float * | scratch_logits, | ||
| float * | scratch_d_logits | ||
| ) |
Definition at line 25 of file loss_kernels_bf16.c.
References bf16_tensor_to_float(), float_tensor_to_bf16(), softmax_cross_entropy_loss(), and vocab_size.
| void swiglu_backward | ( | const float * | input, |
| const float * | d_output, | ||
| float * | d_input, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
SwiGLU backward pass
test_swiglu.py::TestSwiGLUBackward::test_backward_tokens
test_swiglu.py::TestSwiGLUBackward::test_backward_single
test_parity.py::test_swiglu_backward_parity
Computes dGate and dUp given dY. dGate = dy * b * silu'(a), dUp = dy * silu(a)
After changes: make test && make llamacpp-parity-full
Definition at line 215 of file swiglu_kernels.c.
References __attribute__(), sigmoid_scalar(), and silu().
Referenced by ck_layer_backward_rmsnorm_swiglu().
| void swiglu_backward_bf16 | ( | const uint16_t * | input, |
| const uint16_t * | d_output, | ||
| uint16_t * | d_input, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
Definition at line 108 of file swiglu_kernels_bf16.c.
References bf16_to_float(), float_to_bf16(), sigmoid_scalar(), and silu().
| void swiglu_backward_exact | ( | const float * | input, |
| const float * | d_output, | ||
| float * | d_input, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
SwiGLU backward pass (exact version using stdlib sigmoid)
test_swiglu.py::TestSwiGLUBackward::test_exact_vs_fast
test_swiglu.py::TestSwiGLUBackward::test_exact_single
Uses standard library expf for numerical accuracy reference.
After changes: make test
Definition at line 373 of file swiglu_kernels.c.
References sigmoid_scalar(), and silu().
| void swiglu_forward | ( | const float * | input, |
| float * | output, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
SwiGLU forward pass
test_swiglu.py::TestSwiGLUForward::test_forward_tokens
test_swiglu.py::TestSwiGLUForward::test_forward_single
test_mlp.py::TestMLPForward::test_swiglu_mlp
test_fused_swiglu_decode.py::TestFusedSwiGLUDecode::test_fused_swiglu_decode
test_parity.py::test_swiglu_parity
SwiGLU: y = silu(gate) * up where silu(x) = x * sigmoid(x)
After changes: make test && make llamacpp-parity-full
Definition at line 131 of file swiglu_kernels.c.
References __attribute__(), sigmoid_scalar(), and silu().
Referenced by ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_mlp_swiglu_forward_quant(), ck_mlp_swiglu_forward_ref(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), 
qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().
| void swiglu_forward_bf16 | ( | const uint16_t * | input, |
| uint16_t * | output, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
Definition at line 66 of file swiglu_kernels_bf16.c.
References bf16_to_float(), float_to_bf16(), sigmoid_scalar(), and silu().
| void swiglu_forward_exact | ( | const float * | input, |
| float * | output, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
SwiGLU forward pass (exact version using stdlib sigmoid)
test_swiglu.py::TestSwiGLUForward::test_exact_vs_fast
test_swiglu.py::TestSwiGLUForward::test_exact_single
Uses standard library expf for numerical accuracy reference.
After changes: make test
Definition at line 339 of file swiglu_kernels.c.
References sigmoid_scalar(), and silu().
| void topk_batched_f32 | ( | const float * | scores, |
| int | num_tokens, | ||
| int | n_experts, | ||
| int | k, | ||
| int * | indices, | ||
| float * | weights | ||
| ) |
Batched top-K selection for multiple tokens.
| scores | Input scores [num_tokens, n_experts] |
| num_tokens | Number of tokens |
| n_experts | Number of experts |
| k | Number of experts to select per token |
| indices | Output: selected expert indices [num_tokens, k] |
| weights | Output: routing weights [num_tokens, k] (can be NULL for no softmax) |
Definition at line 191 of file topk_kernels.c.
References topk_f32(), and topk_softmax_f32().
| void topk_f32 | ( | const float * | scores, |
| int | n, | ||
| int | k, | ||
| int * | indices, | ||
| float * | values | ||
| ) |
Find top-K indices and values from a score vector.
| scores | Input scores [n] |
| n | Number of scores (e.g., number of experts) |
| k | Number of top scores to select |
| indices | Output: indices of top-K scores [k], sorted descending by value |
| values | Output: top-K score values [k], sorted descending (can be NULL) |
Definition at line 49 of file topk_kernels.c.
Referenced by topk_batched_f32(), and topk_softmax_f32().
| void topk_softmax_f32 | ( | const float * | scores, |
| int | n, | ||
| int | k, | ||
| int * | indices, | ||
| float * | weights | ||
| ) |
Find top-K indices with softmax-normalized weights.
| scores | Input scores [n] (router logits) |
| n | Number of scores |
| k | Number of top scores to select |
| indices | Output: indices of top-K scores [k] |
| weights | Output: softmax-normalized weights for selected [k], sum to 1.0 |
Definition at line 134 of file topk_kernels.c.
References topk_f32().
Referenced by topk_batched_f32().
| void unfused_rmsnorm_qkv_prefill | ( | const float * | x, |
| const float * | gamma, | ||
| const float * | Wq, | ||
| const float * | Wk, | ||
| const float * | Wv, | ||
| float * | x_norm, | ||
| float * | Q, | ||
| float * | K, | ||
| float * | V, | ||
| int | seq_len, | ||
| int | hidden, | ||
| int | q_dim, | ||
| int | kv_dim, | ||
| float | eps | ||
| ) |
Unfused version for benchmarking comparison.
Unfused version for benchmarking comparison.
Definition at line 667 of file prefill_fused_gemm.c.
References gemm_tile_nt_strided(), PREFILL_TILE_M, PREFILL_TILE_N, and rmsnorm_tile().
| void vec_dot_q6_k_q8_k | ( | int | n, |
| float * | s, | ||
| const void * | vx, | ||
| const void * | vy | ||
| ) |
Q6_K x Q8_K dot product (single row)
Definition at line 954 of file gemm_kernels_q6k_q8k.c.
References dot_q6_k_q8_k_ref().
| void weighted_sum_f32 | ( | float * | y, |
| const float ** | vectors, | ||
| const float * | weights, | ||
| int | k, | ||
| int | n | ||
| ) |
Weighted sum of k vectors: y = sum_i(weights[i] * vectors[i])
| y | Output vector [n] |
| vectors | Array of k input vector pointers, each [n] |
| weights | Array of k scalar weights |
| k | Number of vectors to combine |
| n | Vector length |
Definition at line 155 of file axpy_kernels.c.
References axpy_f32(), and scal_copy_f32().