API Reference
Complete API documentation for C-Kernel-Engine kernels. All functions are exported from libckernel_engine.so and can be called from C or via Python ctypes.
This documentation is extracted from the C header files using Doxygen. Functions marked Forward compute activations, and Backward compute gradients.
Quick Reference
Include Header
#include "ckernel_engine.h"
Link Library
-lckernel_engine
Python ctypes
lib = ctypes.CDLL("libckernel_engine.so")
Memory Layouts
All kernels use consistent memory layouts optimized for cache efficiency:
| Buffer | Layout | Description |
|---|---|---|
| input/output | [B, T, D] | Batch × Tokens × Embedding dimension |
| Q | [H, T, d_k] | num_heads × Tokens × head_dim (head-major) |
| K, V | [H_kv, T, d_k] | num_kv_heads × Tokens × head_dim (for GQA) |
| scores | [H, T, T] | num_heads × query_tokens × key_tokens |
| weights | [out, in] | Row-major weight matrices |
Kernel Functions
GEMM (Matrix Multiplication) (92)
ck_gemm_add_bias
Forward
void ck_gemm_add_bias(float * C, const float * bias, int M, int N)
ck_gemm_nt_head_major_q5_0
Forward
void ck_gemm_nt_head_major_q5_0(const float * attn_out, const void * wo, const float * bias, float * output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (auto-dispatch)
ck_gemm_nt_head_major_q8_0
Forward
void ck_gemm_nt_head_major_q8_0(const float * attn_out, const void * wo, const float * bias, float * output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q8_0 weights)
ck_gemm_nt_quant
Forward
void ck_gemm_nt_quant(const float * A, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dtype)
ck_test_gemm_q4_k
Forward
void ck_test_gemm_q4_k(const void * weight_q4k, const float * input_f32, float * output, int rows, int cols, int n_tokens)
Q4_K GEMM - batched matrix multiply with quantized weights.
ck_test_gemm_q5_0
Forward
void ck_test_gemm_q5_0(const void * weight_q5_0, const float * input_f32, float * output, int rows, int cols, int n_tokens)
Test Q5_0 x Q8_0 GEMM (batch matrix multiply)
ck_test_gemm_q5_1
Forward
void ck_test_gemm_q5_1(const void * weight_q5_1, const float * input_f32, float * output, int rows, int cols, int n_tokens)
Test Q5_1 x Q8_0 GEMM (batch matrix multiply)
ck_test_gemm_q5_k
Forward
void ck_test_gemm_q5_k(const void * weight_q5_k, const float * input_f32, float * output, int rows, int cols, int n_tokens)
Test Q5_K x Q8_K GEMM (batch matrix multiply)
ck_test_gemm_q6_k
Forward
void ck_test_gemm_q6_k(const void * weight_q6k, const float * input_f32, float * output, int rows, int cols, int n_tokens)
Test Q6_K x Q8_K GEMM (batch matrix multiply)
ck_test_gemm_q8_0
Forward
void ck_test_gemm_q8_0(const void * weight_q8_0, const float * input_f32, float * output, int rows, int cols, int n_tokens)
Test Q8_0 x Q8_0 GEMM (batch matrix multiply)
ckernel_sgemm_native
Forward
void ckernel_sgemm_native(int M, int N, int K, const float * A, int lda, const float * B, int ldb, const float * bias, float * C, int ldc)
Native GEMM backend that directly reuses the C-Transformer GEMM kernel.
compute_gemm_params
Forward
void compute_gemm_params(const CPUInfo * cpu, GEMMParams * params)
fused_rmsnorm_gemm_2d_tiled
Forward
void fused_rmsnorm_gemm_2d_tiled(const float * x, const float * gamma, const float * W, float * output, int seq_len, int hidden, int out_dim, float eps, float * x_norm_scratch)
Fused RMSNorm + single GEMM with 2D tiling (weight reuse)
gemm_avx512_parallel
Forward
void gemm_avx512_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_batch_int8_impl_name
Forward
const char * gemm_batch_int8_impl_name(void)
Get the best implementation name for logging/debugging.
gemm_bf16_fp32out
Forward
void gemm_bf16_fp32out(const uint16_t * A, const uint16_t * B, const float * bias, float * C, int M, int N, int K)
gemm_bias_gelu_fused
Forward
void gemm_bias_gelu_fused(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_bias_relu_fused
Forward
void gemm_bias_relu_fused(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_bias_silu_fused
Forward
void gemm_bias_silu_fused(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_blocked_serial
Forward
void gemm_blocked_serial(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_blocked_serial_bf16
Forward
void gemm_blocked_serial_bf16(const uint16_t * A, const uint16_t * B, const uint16_t * bias, uint16_t * C, int M, int N, int K)
gemm_f16
Forward
void gemm_f16(float * Y, const uint16_t * W, const float * X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.
gemm_f16_backward
Backward
void gemm_f16_backward(float * dX, const uint16_t * W, const float * dY, int M, int N, int K)
Batched backward pass.
gemm_f16_ref
Forward
void gemm_f16_ref(float * Y, const uint16_t * W, const float * X, int M, int N, int K)
Matrix-matrix multiply with FP16 weights (scalar reference)
gemm_fine_grained_parallel
Forward
void gemm_fine_grained_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_get_backend
const char * gemm_get_backend(void)
gemm_init_threads
Forward
void gemm_init_threads(void)
gemm_microkernel
Forward
void gemm_microkernel(const float * A, const float * B, float * C, int M, int N, int K, int B_transposed)
gemm_microkernel_blocked
Forward
void gemm_microkernel_blocked(const float * A, const float * B, float * C, int M, int N, int K)
gemm_microkernel_blocked_bt
Forward
void gemm_microkernel_blocked_bt(const float * A, const float * B, float * C, int M, int N, int K)
gemm_microkernel_edge
Forward
void gemm_microkernel_edge(int m, int n, int K, const float * A, int lda, const float * B, int ldb, float * C, int ldc, int first_k)
gemm_microkernel_packed
Forward
void gemm_microkernel_packed(const float * A, const float * B, float * C, int M, int N, int K)
gemm_microkernel_sequential
Forward
void gemm_microkernel_sequential(const float * A, const float * B, float * C, int M, int N, int K)
gemm_naive_parallel
Forward
void gemm_naive_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_naive_serial_double
Forward
void gemm_naive_serial_double(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_nn_avx512
Forward
void gemm_nn_avx512(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_nn_bf16
Forward
void gemm_nn_bf16(const uint16_t * A, const uint16_t * B, const uint16_t * bias, uint16_t * C, int M, int N, int K)
gemm_nn_blocked
Forward
void gemm_nn_blocked(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_nn_parallel
Forward
void gemm_nn_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_nn_serial_double
Forward
void gemm_nn_serial_double(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_nt
Forward
void gemm_nt(const float * input, const float * weight, float * output, int rows, int cols, int common)
gemm_nt_matvec_parallel
Forward
void gemm_nt_matvec_parallel(const float * A, const float * B, const float * bias, float * C, int N, int K)
gemm_nt_q4_0
Forward
void gemm_nt_q4_0(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
gemm_nt_q4_1
Forward
void gemm_nt_q4_1(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
gemm_nt_q4_k
Forward
void gemm_nt_q4_k(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q4_k_q8_k
Forward
void gemm_nt_q4_k_q8_k(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q5_0
Forward
void gemm_nt_q5_0(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q5_0_q8_0
Forward
void gemm_nt_q5_0_q8_0(const void * A_q8, const void * B_q5, const float * bias, float * C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
gemm_nt_q5_0_q8_0_ref
Forward
void gemm_nt_q5_0_q8_0_ref(const void * A, const void * B, float * C, int M, int N, int K)
Dispatcher for gemm_nt_q8_0_q8_0.
gemm_nt_q5_0_q8_0_unroll_avx
Forward
void gemm_nt_q5_0_q8_0_unroll_avx(const void * A_q8, const void * B_q5, const float * bias, float * C, int M, int N, int K)
gemm_nt_q5_0_ref
Forward
void gemm_nt_q5_0_ref(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
GEMM with transposed Q5_0 weights: C = A @ B^T.
gemm_nt_q5_0_sse
Forward
void gemm_nt_q5_0_sse(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q5_0_sse_v2
Forward
void gemm_nt_q5_0_sse_v2(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q5_1
Forward
void gemm_nt_q5_1(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
gemm_nt_q5_k
Forward
void gemm_nt_q5_k(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q5_k_ref
Forward
void gemm_nt_q5_k_ref(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q6_k
Forward
void gemm_nt_q6_k(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q6_k_q8_k
Forward
void gemm_nt_q6_k_q8_k(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
gemm_nt_q6_k_ref
Forward
void gemm_nt_q6_k_ref(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q6_k_sse
Forward
void gemm_nt_q6_k_sse(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q8_0
Forward
void gemm_nt_q8_0(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
gemm_nt_q8_0_dispatch
Forward
void gemm_nt_q8_0_dispatch(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dt)
gemm_nt_q8_0_mlp_dispatch
Forward
void gemm_nt_q8_0_mlp_dispatch(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dt)
gemm_nt_q8_0_q8_0
Forward
void gemm_nt_q8_0_q8_0(const void * A, const void * B, const float * bias, float * C, int M, int N, int K)
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
gemm_nt_q8_0_q8_0_ref
Forward
void gemm_nt_q8_0_q8_0_ref(const void * A, const void * B, float * C, int M, int N, int K)
Scalar reference: gemm_nt_q8_0_q8_0.
gemm_nt_q8_k_mlp_dispatch
Forward
void gemm_nt_q8_k_mlp_dispatch(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dt)
gemm_nt_q8_k_qkv_dispatch
Forward
void gemm_nt_q8_k_qkv_dispatch(const void * A_q8k, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dt)
gemm_q4_0
Forward
void gemm_q4_0(float * Y, const void * W, const float * X, int M, int N, int K)
Matrix-matrix multiply with Q4_0 weights.
gemm_q4_0_backward
Backward
void gemm_q4_0_backward(float * dX, const void * W, const float * dY, int M, int N, int K)
Batched backward pass.
gemm_q4_1
Forward
void gemm_q4_1(float * Y, const void * W, const float * X, int M, int N, int K)
Matrix-matrix multiply with Q4_1 weights.
gemm_q4_1_backward
Backward
void gemm_q4_1_backward(float * dX, const void * W, const float * dY, int M, int N, int K)
Batched backward pass.
gemm_q4_k
Forward
void gemm_q4_k(float * Y, const void * W, const float * X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.
gemm_q4_k_backward
Backward
void gemm_q4_k_backward(float * dX, const void * W, const float * dY, int M, int N, int K)
Batched backward pass.
gemm_q4_k_q8_k
Forward
void gemm_q4_k_q8_k(float * Y, const void * W, const void * X_q8, int M, int N, int K)
gemm_q4_k_q8_k_ref
Forward
void gemm_q4_k_q8_k_ref(float * Y, const void * W, const void * X_q8, int M, int N, int K)
gemm_q4_k_ref
Forward
void gemm_q4_k_ref(float * Y, const void * W, const float * X, int M, int N, int K)
Matrix-matrix multiply with Q4_K weights (scalar reference)
gemm_q5_0
Forward
void gemm_q5_0(float * Y, const void * W, const float * X, int M, int N, int K)
Matrix-matrix multiply with Q5_0 weights.
gemm_q5_0_backward
Backward
void gemm_q5_0_backward(float * dX, const void * W, const float * dY, int M, int N, int K)
Batched backward pass.
gemm_q5_1
Forward
void gemm_q5_1(float * Y, const void * W, const float * X, int M, int N, int K)
Matrix-matrix multiply with Q5_1 weights.
gemm_q5_1_backward
Backward
void gemm_q5_1_backward(float * dX, const void * W, const float * dY, int M, int N, int K)
Batched backward pass.
gemm_q6_k
Forward
void gemm_q6_k(float * Y, const void * W, const float * X, int M, int N, int K)
gemm_q6_k_q8_k
Forward
void gemm_q6_k_q8_k(float * Y, const void * W, const void * X_q8, int M, int N, int K)
GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.
gemm_q8_0
Forward
void gemm_q8_0(float * Y, const void * W, const float * X, int M, int N, int K)
Matrix-matrix multiply with Q8_0 weights.
gemm_q8_0_backward
Backward
void gemm_q8_0_backward(float * dX, const void * W, const float * dY, int M, int N, int K)
Batched backward pass.
gemm_swiglu_fused
Forward
void gemm_swiglu_fused(const float * x, const float * W_gate, const float * W_up, const float * b_gate, const float * b_up, float * output, int M, int N, int K)
gemm_tile_nt_strided
Forward
void gemm_tile_nt_strided(const float * A, const float * B_tile, float * C, int tile_m, int tile_n, int K, int C_stride)
GEMM tile with N-dimension tiling (weight reuse)
gemm_tn_avx512
Forward
void gemm_tn_avx512(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_tn_bf16
Forward
void gemm_tn_bf16(const uint16_t * A, const uint16_t * B, const uint16_t * bias, uint16_t * C, int M, int N, int K)
gemm_tn_blocked
Forward
void gemm_tn_blocked(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_tn_parallel
Forward
void gemm_tn_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
gemm_tn_serial_double
Forward
void gemm_tn_serial_double(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)
get_gemm_params
Forward
const GEMMParams * get_gemm_params(void)
Layer Normalization (10)
layernorm_backward_kernel
Backward
void layernorm_backward_kernel(const float * d_output, const float * input, const float * gamma, const float * mean, const float * rstd, float * d_input, float * d_gamma, float * d_beta, int tokens, int d_model, int aligned_embed_dim)
Backward pass / gradient computation
layernorm_backward_kernel_bf16
Backward
void layernorm_backward_kernel_bf16(const uint16_t * d_output, const uint16_t * input, const float * gamma, const float * mean, const float * rstd, uint16_t * d_input, float * d_gamma, float * d_beta, int tokens, int d_model, int aligned_embed_dim, float * scratch_d_output, float * scratch_input, float * scratch_d_input)
Backward pass / gradient computation
layernorm_forward_rolled_slice
Forward
void layernorm_forward_rolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps)
Forward pass computation
layernorm_forward_rolled_slice_bf16
Forward
void layernorm_forward_rolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps, float * scratch_input, float * scratch_output)
Forward pass computation
layernorm_forward_unrolled_slice
Forward
void layernorm_forward_unrolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)
Forward pass computation
layernorm_forward_unrolled_slice_bf16
Forward
void layernorm_forward_unrolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps, float * scratch_input, float * scratch_output)
Forward pass computation
layernorm_forward_unrolled_slice_scalar
Forward
void layernorm_forward_unrolled_slice_scalar(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)
Forward pass computation
layernorm_naive_serial
Forward
void layernorm_naive_serial(const float * input, const float * gamma, const float * beta, float * output, float * mean_cache, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
layernorm_naive_serial_matched_precision
Forward
void layernorm_naive_serial_matched_precision(const float * input, const float * gamma, const float * beta, float * output, float * mean_cache, float * rstd_cache, int tokens, int d_model, float eps)
zero_layernorm_padding
Forward
void zero_layernorm_padding(float * out_ptr, int d_model, int aligned_embed_dim)
RMS Normalization (42)
ck_layer_backward_rmsnorm_swiglu
Backward
void ck_layer_backward_rmsnorm_swiglu(const CKLayerBackwardParams * p)
Backward pass / gradient computation
ck_layer_forward_rmsnorm_swiglu
Forward
void ck_layer_forward_rmsnorm_swiglu(const CKLayerForwardParams * p)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_decode
Forward
void ck_layer_forward_rmsnorm_swiglu_decode(const CKLayerForwardParams * p, int token_index, int cache_capacity)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_decode_fused
Forward
void ck_layer_forward_rmsnorm_swiglu_decode_fused(const CKLayerForwardParams * p, int token_index, int cache_capacity)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_decode_fused_attn
Forward
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(const CKLayerForwardParams * p, int token_index, int cache_capacity)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl
Forward
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(const CKLayerForwardParams * p, int token_index, int cache_capacity, int fuse_mlp)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp
Forward
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp(const CKLayerForwardParams * p, int token_index, int cache_capacity)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_decode_q4_k
Forward
void ck_layer_forward_rmsnorm_swiglu_decode_q4_k(const CKLayerForwardParamsQ4K * p, int token_index, int cache_capacity)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_decode_quant
Forward
void ck_layer_forward_rmsnorm_swiglu_decode_quant(const CKLayerForwardParamsQ4K * p, int token_index, int cache_capacity)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_q4_k
Forward
void ck_layer_forward_rmsnorm_swiglu_q4_k(const CKLayerForwardParamsQ4K * p)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_quant
Forward
void ck_layer_forward_rmsnorm_swiglu_quant(const CKLayerForwardParamsQ4K * p)
Forward pass computation
ck_layer_forward_rmsnorm_swiglu_ref
Forward
void ck_layer_forward_rmsnorm_swiglu_ref(const CKLayerForwardParams * p)
Forward pass computation
ck_test_rmsnorm
Forward
void ck_test_rmsnorm(const float * input, const float * weight, float * output, int n_tokens, int dim, float eps)
RMSNorm.
fused_rmsnorm
Forward
void fused_rmsnorm(const float * input, const float * gamma, const float * beta, float * output, int hidden, float eps)
Fused RMSNorm - writes to pre-allocated buffer.
fused_rmsnorm_linear_q4k
Forward
void fused_rmsnorm_linear_q4k(float * y, const float * x, const float * gamma, const void * W_q4k, int M, int K, float eps)
Fused RMSNorm + Q4_K Linear projection.
fused_rmsnorm_qkv
Forward
void fused_rmsnorm_qkv(const float * input, const float * gamma, const float * W_qkv, const float * b_qkv, float * q_out, float * k_out, float * v_out, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps)
Fused RMSNorm with fused QKV projection.
fused_rmsnorm_qkv_prefill
Forward
void fused_rmsnorm_qkv_prefill(const float * x, const float * gamma, const float * Wq, const float * Wk, const float * Wv, float * Q, float * K, float * V, int seq_len, int hidden, int q_dim, int kv_dim, float eps, float * scratch)
Fused RMSNorm + QKV projection for prefill (v3 optimized)
fused_rmsnorm_qkv_prefill_head_major
Forward
void fused_rmsnorm_qkv_prefill_head_major(const float * x, const float * gamma, const float * Wq, const float * Bq, const float * Wk, const float * Bk, const float * Wv, const float * Bv, float * Q, float * K, float * V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float * scratch)
Fused RMSNorm + QKV projection for prefill (head-major outputs)
fused_rmsnorm_qkv_prefill_head_major_quant
Forward
void fused_rmsnorm_qkv_prefill_head_major_quant(const float * x, const float * gamma, const void * Wq, const float * Bq, CKDataType wq_dt, const void * Wk, const float * Bk, CKDataType wk_dt, const void * Wv, const float * Bv, CKDataType wv_dt, float * Q, float * K, float * V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void * scratch)
Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size
Forward
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)
Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.
fused_rmsnorm_qkv_scratch_size
Forward
size_t fused_rmsnorm_qkv_scratch_size(int hidden)
Get scratch size for fused prefill.
mega_fuse_rmsnorm_qkv
Forward
void mega_fuse_rmsnorm_qkv(float * q_out, float * k_out, float * v_out, const float * input, const float * gamma, const float * W_qkv, const float * b_qkv, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps)
Phase 1: Fused RMSNorm + QKV (intermediates in registers)
mega_fuse_rmsnorm_qkv_avx
Forward
void mega_fuse_rmsnorm_qkv_avx(float * q_out, float * k_out, float * v_out, const float * input, const float * gamma, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps)
Fused RMSNorm + QKV for decode (single token)
mega_fuse_rmsnorm_qkv_rope
Forward
void mega_fuse_rmsnorm_qkv_rope(float * q_out, float * k_out, float * v_out, const float * input, const float * gamma, const float * W_qkv, const float * b_qkv, const float * rope_cos, const float * rope_sin, int pos, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps)
Phase 2: Fused RMSNorm + QKV + RoPE.
rmsnorm_backward
Backward
void rmsnorm_backward(const float * d_output, const float * input, const float * gamma, const float * rstd_cache, float * d_input, float * d_gamma, int tokens, int d_model, int aligned_embed_dim)
RMSNorm backward pass. Tests: test_rmsnorm.py::TestRMSNormBackward::test_backward_tokens, test_rmsnorm.py::TestRMSNormBackward::test_backward_single, test_parity.py::test_rmsnorm_backward_parity
rmsnorm_backward_bf16
Backward
void rmsnorm_backward_bf16(const uint16_t * d_output, const uint16_t * input, const float * gamma, const float * rstd_cache, uint16_t * d_input, float * d_gamma, int tokens, int d_model, int aligned_embed_dim)
Backward pass / gradient computation
rmsnorm_backward_int4
Backward
void rmsnorm_backward_int4(const uint8_t * d_output, const uint8_t * input, const float * gamma, const float * rstd_cache, uint8_t * d_input, float * d_gamma, int tokens, int d_model, int aligned_embed_dim, float * scratch_d_output, float * scratch_input, float * scratch_d_input)
Backward pass / gradient computation
rmsnorm_backward_int8
Backward
void rmsnorm_backward_int8(const int8_t * d_output, const int8_t * input, const float * gamma, const float * rstd_cache, int8_t * d_input, float * d_gamma, int tokens, int d_model, int aligned_embed_dim, float * scratch_d_output, float * scratch_input, float * scratch_d_input)
Backward pass / gradient computation
rmsnorm_forward
Forward
void rmsnorm_forward(const float * input, const float * gamma, float * output, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
RMSNorm forward pass. Tests: test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens, test_rmsnorm.py::TestRMSNormForward::test_fp32_single, test_rmsnorm.py::TestRMSNormForward::test_perf_rolled, test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat, test_parity.py::test_rmsnorm_parity
rmsnorm_forward_bf16
Forward
void rmsnorm_forward_bf16(const uint16_t * input, const float * gamma, uint16_t * output, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
Forward pass computation
rmsnorm_forward_int4
Forward
void rmsnorm_forward_int4(const uint8_t * input, const float * gamma, uint8_t * output, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float * scratch_input, float * scratch_output)
Forward pass computation
rmsnorm_forward_int8
Forward
void rmsnorm_forward_int8(const int8_t * input, const float * gamma, int8_t * output, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float * scratch_input, float * scratch_output)
Forward pass computation
rmsnorm_q8_k_fused
Forward
void rmsnorm_q8_k_fused(const float * input, const float * gamma, void * vy, int tokens, int d_model, int aligned_embed_dim, float eps)
Fused RMSNorm + Q8_K Quantization
rmsnorm_qkv_fp32_fused
Forward
void rmsnorm_qkv_fp32_fused(const float * x, const float * rms_weight, const float * wq, const float * wk, const float * wv, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)
rmsnorm_qkv_fp32_fused_v2
Forward
void rmsnorm_qkv_fp32_fused_v2(const float * x, const float * rms_weight, const float * wq, const float * wk, const float * wv, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)
rmsnorm_qkv_fp32_fused_v3
Forward
void rmsnorm_qkv_fp32_fused_v3(const float * x, const float * rms_weight, const float * wq, const float * wk, const float * wv, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)
rmsnorm_qkv_q4k_fused
Forward
void rmsnorm_qkv_q4k_fused(const float * x, const float * rms_weight, const void * wq, const void * wk, const void * wv, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)
rmsnorm_qkv_separate_fp32
Forward
void rmsnorm_qkv_separate_fp32(const float * x, const float * rms_weight, const float * wq, const float * wk, const float * wv, float * normed, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)
rmsnorm_tile
Forward
void rmsnorm_tile(const float * input, const float * gamma, float * output, int tile_m, int embed_dim, int aligned_embed_dim, float eps)
Compute RMSNorm for a tile of tokens.
simple_rmsnorm
Forward
void simple_rmsnorm(const float * input, const float * gamma, float * output, int tokens, int d_model, float eps)
unfused_rmsnorm_linear_q4k_ref
Forward
void unfused_rmsnorm_linear_q4k_ref(float * y, const float * x, const float * gamma, const void * W_q4k, int M, int K, float eps)
Reference (unfused) implementation for correctness testing.
unfused_rmsnorm_qkv_prefill
Forward
void unfused_rmsnorm_qkv_prefill(const float * x, const float * gamma, const float * Wq, const float * Wk, const float * Wv, float * x_norm, float * Q, float * K, float * V, int seq_len, int hidden, int q_dim, int kv_dim, float eps)
Unfused version for comparison.
GELU Activation (10)
fast_gelu_scalar
Forward
float fast_gelu_scalar(float x)
gelu_backward_exact
Backward
void gelu_backward_exact(const float * input, const float * d_output, float * d_input, size_t n)
Backward pass / gradient computation
gelu_backward_exact_bf16
Backward
void gelu_backward_exact_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, size_t n, float * scratch_input, float * scratch_d_output, float * scratch_d_input)
Backward pass / gradient computation
gelu_backward_fast
Backward
void gelu_backward_fast(const float * input, const float * d_output, float * d_input, size_t n)
Backward pass / gradient computation
gelu_backward_fast_bf16
Backward
void gelu_backward_fast_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, size_t n, float * scratch_input, float * scratch_d_output, float * scratch_d_input)
Backward pass / gradient computation
gelu_backward_scalar
Backward
void gelu_backward_scalar(const float * input, const float * d_output, float * d_input, size_t n)
Backward pass / gradient computation
gelu_exact_inplace
Forward
void gelu_exact_inplace(float * data, size_t n)
gelu_fast_inplace
Forward
void gelu_fast_inplace(float * data, size_t n)
GELU activation forward (fast approximation, in-place). Tests: test_gelu.py::TestGELUForward::test_gelu_fast_inplace, test_gelu.py::TestGELUForward::test_gelu_vs_exact, test_parity.py::test_gelu_parity
gelu_fast_inplace_bf16
Forward
void gelu_fast_inplace_bf16(uint16_t * data, size_t n, float * scratch)
gelu_scalar
Forward
float gelu_scalar(float x)
Softmax (11)
backward_causal_softmax_head_major
Backward
void backward_causal_softmax_head_major(float * d_scores, const float * weights, int num_heads, int num_tokens, int aligned_context_window)
Backward pass / gradient computation
backward_causal_softmax_head_major_bf16
Backward
void backward_causal_softmax_head_major_bf16(uint16_t * d_scores, const uint16_t * weights, int num_heads, int num_tokens, int aligned_context_window, float * scratch_d_scores, float * scratch_weights)
Backward pass / gradient computation
causal_softmax_head_major
Forward
void causal_softmax_head_major(float * scores, int num_heads, int num_tokens, int aligned_context_window)
Causal softmax (in-place, row-wise). Tests: test_softmax.py::TestSoftmaxForward::test_causal_softmax, test_softmax.py::TestSoftmaxForward::test_causal_vs_softmax, test_attention.py::TestAttentionForward::test_softmax_correctness
causal_softmax_head_major_bf16
Forward
void causal_softmax_head_major_bf16(uint16_t * scores, int num_heads, int num_tokens, int aligned_context_window, float * scratch)
causal_softmax_head_major_exact
Forward
void causal_softmax_head_major_exact(float * scores, int num_heads, int num_tokens, int aligned_context_window)
Causal softmax (exact version using stdlib expf). Tests: test_softmax.py::TestSoftmaxForward::test_causal_softmax_exact, test_softmax.py::TestSoftmaxForward::test_exact_vs_fast
ck_test_softmax
Forward
void ck_test_softmax(const float * input, float * output, int n)
Softmax (simple, non-causal)
softmax
Forward
void softmax(float * x, int n)
softmax_cross_entropy_loss
Forward
void softmax_cross_entropy_loss(const float * logits, const int32_t * targets, int tokens, int vocab_size, float * d_logits, float * loss_out)
softmax_cross_entropy_loss_bf16
Forward
void softmax_cross_entropy_loss_bf16(const uint16_t * logits, const int32_t * targets, int tokens, int vocab_size, uint16_t * d_logits, float * loss_out, float * scratch_logits, float * scratch_d_logits)
softmax_inplace
Forward
void softmax_inplace(float * x, int n)
topk_softmax_f32
Forward
void topk_softmax_f32(const float * scores, int n, int k, int * indices, float * weights)
Find top-K indices with softmax-normalized weights.
Attention (48)
attention_backward_causal_head_major
Backward
void attention_backward_causal_head_major(const float * d_output, const float * q, const float * k, const float * v, const float * attn_weights, float * d_q, float * d_k, float * d_v, float * d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
Causal attention backward (non-GQA version). Tests: test_attention_backward.py::TestAttentionBackward::test_backward, test_attention_backward.py::TestAttentionBackward::test_backward_vs_separate, test_parity.py::test_attention_backward_parity
attention_backward_causal_head_major_gqa
Backward
void attention_backward_causal_head_major_gqa(const float * d_output, const float * q, const float * k, const float * v, const float * attn_weights, float * d_q, float * d_k, float * d_v, float * d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
GQA causal attention backward (score-matrix version). Tests: test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward, test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_vs_separate, test_parity.py::test_attention_backward_parity
attention_backward_causal_head_major_gqa_bf16
Backward
void attention_backward_causal_head_major_gqa_bf16(const uint16_t * d_output, float * d_x, const uint16_t * q, const uint16_t * k, const uint16_t * v, const float * attn_weights, float * d_q, float * d_k, float * d_v, float * d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float * scratch_d_output, float * scratch_q, float * scratch_k, float * scratch_v)
BF16 attention backward with caller-provided scratch buffers. Tests: bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_backward
attention_flash_cleanup
Forward
void attention_flash_cleanup(void)
Clean up flash attention resources.
attention_flash_decode
Forward
void attention_flash_decode(float * out, const float * q, const float * k, const float * v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.
attention_flash_decode_scalar
Forward
void attention_flash_decode_scalar(float * out, const float * q, const float * k, const float * v, int T_q, int T_k, int H, int D_h, float scale)
Scalar flash-style attention (online softmax)
attention_flash_init
Forward
void attention_flash_init(int max_context, int max_heads, int max_head_dim)
Initialize flash attention buffers.
attention_flash_query_causal
Forward
void attention_flash_query_causal(const float * q_vec, const float * k_head, const float * v_head, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float * out_vec)
attention_flash_query_sliding
Forward
void attention_flash_query_sliding(const float * q_vec, const float * k_head, const float * v_head, int query_pos, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float * out_vec, int sliding_window)
attention_forward_causal_head_major
Forward
void attention_forward_causal_head_major(const float * q, const float * k, const float * v, float * scores, float * output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
Causal attention forward (score-matrix version). Tests: test_attention.py::TestAttentionForward::test_causal_forward, test_attention.py::TestAttentionForward::test_gqa_broadcast, test_attention.py::TestAttentionForward::test_exact_vs_fast, test_parity.py::test_attention_parity
attention_forward_causal_head_major_exact
Forward
void attention_forward_causal_head_major_exact(const float * q, const float * k, const float * v, float * scores, float * output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
Causal attention forward (exact version using stdlib expf). Tests: test_attention.py::TestAttentionForward::test_exact_single, test_attention.py::TestAttentionForward::test_exact_vs_fast
attention_forward_causal_head_major_gqa
Forward
void attention_forward_causal_head_major_gqa(const float * q, const float * k, const float * v, float * scores, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
GQA causal attention forward (score-matrix version). Tests: test_attention.py::TestAttentionForward::test_gqa_forward, test_attention.py::TestAttentionForward::test_gqa_broadcast, test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward, test_parity.py::test_attention_gqa_parity
attention_forward_causal_head_major_gqa_bf16
Forward
void attention_forward_causal_head_major_gqa_bf16(const uint16_t * q, const uint16_t * k, const uint16_t * v, float * scores, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float * scratch_q, float * scratch_k, float * scratch_v)
BF16 GQA causal attention forward. Tests: bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_forward, bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa, bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_flash
attention_forward_causal_head_major_gqa_exact
Forward
void attention_forward_causal_head_major_gqa_exact(const float * q, const float * k, const float * v, float * scores, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
GQA causal attention forward (exact version using stdlib expf). Tests: test_attention.py::TestAttentionForward::test_gqa_exact, bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa
attention_forward_causal_head_major_gqa_flash
Forward
void attention_forward_causal_head_major_gqa_flash(const float * q, const float * k, const float * v, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
Flash attention forward for GQA (prefill, no score materialization). Tests: test_flash_attention.py::TestFlashAttention::test_flash_forward, test_flash_attention.py::TestFlashAttention::test_flash_vs_score_matrix, test_flash_attention.py::TestFlashAttention::test_flash_gqa, test_attention.py::TestAttentionForward::test_flash_forward
attention_forward_causal_head_major_gqa_flash_strided
Forward
void attention_forward_causal_head_major_gqa_flash_strided(const float * q, const float * k, const float * v, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
Flash attention forward with custom KV stride (for KV cache). Tests: test_flash_attention.py::TestFlashAttention::test_flash_strided, test_kv_cache_attention.py::TestKVCacheAttention::test_flash_attention
attention_forward_causal_head_major_gqa_flash_strided_sliding
Forward
void attention_forward_causal_head_major_gqa_flash_strided_sliding(const float * q, const float * k, const float * v, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
Flash attention forward with sliding window (prefill). Tests: test_attention.py::TestAttentionForward::test_sliding_window_prefill
attention_forward_decode_head_major_gqa_flash
Forward
void attention_forward_decode_head_major_gqa_flash(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
Flash attention decode (single token attends to KV cache). Tests: test_flash_attention.py::TestFlashAttention::test_flash_decode, test_kv_cache_attention.py::TestKVCacheAttention::test_flash_decode, test_fused_attention_decode.py::TestFusedAttentionDecode::test_flash_decode, test_attention.py::TestAttentionForward::test_flash_decode
attention_forward_decode_head_major_gqa_flash_sliding
Forward
void attention_forward_decode_head_major_gqa_flash_sliding(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
Flash attention decode with sliding window. Tests: test_attention.py::TestAttentionForward::test_sliding_window_decode
attention_forward_decode_head_major_gqa_regular
Forward
void attention_forward_decode_head_major_gqa_regular(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!
attention_mlp_fused_fp32
Forward
void attention_mlp_fused_fp32(const float * q, const float * k_cache, const float * v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float * wo, const float * residual_1, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, int embed_dim, int intermediate_dim, float * hidden_out)
attention_mlp_fused_q4k
Forward
void attention_mlp_fused_q4k(const float * q, const float * k_cache, const float * v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const void * wo, const float * residual_1, const float * rms_weight, float eps, const void * w_gate, const void * w_up, const void * w_down, int embed_dim, int intermediate_dim, float * hidden_out)
attention_mlp_separate_fp32
Forward
void attention_mlp_separate_fp32(const float * q, const float * k_cache, const float * v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float * wo, const float * residual_1, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, int embed_dim, int intermediate_dim, float * attn_out_buf, float * hidden_after_attn_buf, float * normed_buf, float * gate_buf, float * up_buf, float * mlp_out_buf, float * hidden_out)
ck_attention_flash_decode_wrapper
Forward
void ck_attention_flash_decode_wrapper(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
Wrapper to call TRUE flash attention from orchestration layer.
ck_attention_project_head_major
Forward
void ck_attention_project_head_major(const float * attn_out, const float * wo, const float * bo, float * out, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
ck_attention_project_head_major_backward
Backward
void ck_attention_project_head_major_backward(const float * d_out, const float * attn_out, const float * wo, float * d_attn_out, float * d_wo, float * d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Backward pass / gradient computation
ck_attention_project_head_major_decode_token
Forward
void ck_attention_project_head_major_decode_token(const float * attn_token, const float * wo, const float * bo, float * out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
ck_attention_project_head_major_decode_token_residual
Forward
void ck_attention_project_head_major_decode_token_residual(const float * attn_token, const float * wo, const float * bo, const float * residual_in, float * proj_out, float * residual_out, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)
ck_attention_project_head_major_q4_k
Forward
void ck_attention_project_head_major_q4_k(const float * attn_out, const void * wo, const float * bo, float * out, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
ck_attention_project_head_major_q4_k_q8_k
Forward
void ck_attention_project_head_major_q4_k_q8_k(const float * attn_out, const void * wo, const float * bo, float * out, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
ck_attention_project_head_major_quant
Forward
void ck_attention_project_head_major_quant(const float * attn_out, const void * wo, const float * bo, float * out, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, CKDataType wo_dtype)
ck_attention_project_head_major_ref
Forward
void ck_attention_project_head_major_ref(const float * attn_out, const float * wo, const float * bo, float * out, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
ck_test_attention_causal
Forward
void ck_test_attention_causal(const float * q, const float * k, const float * v, float * out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim)
Multi-head causal attention for prefill (head-major layout)
ck_test_attention_decode_sliding
Forward
void ck_test_attention_decode_sliding(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int sliding_window)
Test sliding-window attention (decode mode)
ck_test_attention_sliding_window
Forward
void ck_test_attention_sliding_window(const float * q, const float * k, const float * v, float * out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim, int sliding_window)
Test sliding-window attention (prefill)
fused_flash_attention_all_heads
Forward
void fused_flash_attention_all_heads(float * o_out, const float * q_all, const float * kv_cache_k, const float * kv_cache_v, int num_heads, int num_kv_heads, int head_dim, int seq_len, int kv_tile_size)
Fused Flash Attention for all heads (parallel dispatch)
fused_flash_attention_head
Forward
void fused_flash_attention_head(float * o_out, const float * q, const float * kv_cache_k, const float * kv_cache_v, int kv_head_idx, int seq_len, int head_dim, int kv_tile_size)
Fused Flash Attention for single head.
mega_fuse_flash_attention_avx
Forward
void mega_fuse_flash_attention_avx(float * o_out, const float * q, const float * kv_cache_k, const float * kv_cache_v, int num_heads, int num_kv_heads, int seq_len, int cache_capacity, int head_dim, int aligned_head_dim)
Flash attention with online softmax (AVX version)
mega_fused_attention
Forward
void mega_fused_attention(float * output, const float * input, const float * residual, const float * W_qkv, const float * b_qkv, const float * W_o, const float * b_o, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int pos, int seq_len, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps)
Complete mega-fused attention block.
mega_fused_attention_decode
Forward
void mega_fused_attention_decode(float * output, const float * input, const float * residual, const float * ln1_gamma, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, const float * wo, const float * bo, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps)
Mega-fused attention for decode mode (single token)
mega_fused_attention_decode_q5_0
Forward
void mega_fused_attention_decode_q5_0(float * output, const float * input, const float * residual, const void * wq_q5_0, const void * wk_q5_0, const void * wv_q8_0, const void * wo_q5_0, const float * ln_gamma, const float * bq, const float * bk, const float * bv, const float * bo, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void * scratch)
Serial mega-fused attention decode kernel.
mega_fused_attention_decode_q5_0_parallel_simd
Forward
void mega_fused_attention_decode_q5_0_parallel_simd(float * output, const float * input, const float * residual, const void * wq_q5_0, const void * wk_q5_0, const void * wv_q8_0, const void * wo_q5_0, const float * ln_gamma, const float * bq, const float * bk, const float * bv, const float * bo, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void * scratch, int ith, int nth)
Parallel SIMD mega-fused attention decode kernel (threadpool-aware)
mega_fused_attention_decode_scratch_size
Forward
int mega_fused_attention_decode_scratch_size(int AE, int H, int KV, int AD)
Calculate scratch buffer size needed for the kernel.
mega_fused_attention_prefill
Forward
void mega_fused_attention_prefill(float * output, const float * input, const float * residual, const float * ln1_gamma, const void * wq, const float * bq, CKDataType wq_dt, const void * wk, const float * bk, CKDataType wk_dt, const void * wv, const float * bv, CKDataType wv_dt, const void * wo, const float * bo, CKDataType wo_dt, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void * scratch)
Mega-fused attention for prefill mode (multiple tokens)
mega_fused_attention_prefill_q8_0
Forward
void mega_fused_attention_prefill_q8_0(float * output, const float * input, const float * residual, const float * ln1_gamma, const void * wq, const float * bq, CKDataType wq_dt, const void * wk, const float * bk, CKDataType wk_dt, const void * wv, const float * bv, CKDataType wv_dt, const void * wo, const float * bo, CKDataType wo_dt, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void * scratch)
Mega-fused prefill attention kernel (Q8_0 out-proj)
mega_fused_attention_prefill_q8_0_scratch_size
Forward
size_t mega_fused_attention_prefill_q8_0_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill_q8_0.
mega_fused_attention_prefill_scratch_size
Forward
size_t mega_fused_attention_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
Get scratch buffer size for mega_fused_attention_prefill.
simple_attention
Forward
void simple_attention(const float * q, const float * k, const float * v, float * output, int num_heads, int num_kv_heads, int seq_len, int head_dim)
MLP / Feed-Forward (30)
ck_mlp_swiglu_forward
Forward
void ck_mlp_swiglu_forward(const float * input, const float * w1, const float * b1, const float * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
Forward pass computation
ck_mlp_swiglu_forward_fully_fused_token
Forward
void ck_mlp_swiglu_forward_fully_fused_token(const float * input_row, const float * w1, const float * b1, const float * w2, const float * b2, float * output_row, int aligned_embed_dim, int aligned_intermediate_dim)
Forward pass computation
ck_mlp_swiglu_forward_fused_token
Forward
void ck_mlp_swiglu_forward_fused_token(const float * input_row, const float * w1, const float * b1, const float * w2, const float * b2, float * swiglu_row, float * output_row, int aligned_embed_dim, int aligned_intermediate_dim)
Forward pass computation
ck_mlp_swiglu_forward_q4_k
Forward
void ck_mlp_swiglu_forward_q4_k(const float * input, const void * w1, const float * b1, const void * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
Forward pass computation
ck_mlp_swiglu_forward_q4_k_q8_k
Forward
void ck_mlp_swiglu_forward_q4_k_q8_k(const float * input, const void * w1, const float * b1, const void * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int aligned_embed_dim, int aligned_intermediate_dim)
Forward pass computation
ck_mlp_swiglu_forward_q4_k_q8_k_prefill
Forward
void ck_mlp_swiglu_forward_q4_k_q8_k_prefill(const float * input, const void * w1, const float * b1, const void * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
Forward pass computation
ck_mlp_swiglu_forward_quant
Forward
void ck_mlp_swiglu_forward_quant(const float * input, const void * w1, const float * b1, CKDataType w1_dtype, const void * w2, const float * b2, CKDataType w2_dtype, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
Forward pass computation
ck_mlp_swiglu_forward_ref
Forward
void ck_mlp_swiglu_forward_ref(const float * input, const float * w1, const float * b1, const float * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)
Forward pass computation
ck_test_outproj_mlp_fused_q5_0
Forward
void ck_test_outproj_mlp_fused_q5_0(const float * attn_out, const float * residual, const float * ln2_gamma, const void * wo, const void * w1, const void * w2, float * output, int tokens, int num_heads, int head_dim, int embed_dim, int intermediate, float eps, int w2_is_q6k)
Test mega-fused OutProj + MLP kernel (Q5_0 weights)
fused_mlp_swiglu_decode
Forward
void fused_mlp_swiglu_decode(const float * x, const float * W_gate, const float * W_up, const float * W_down, const float * b_gate, const float * b_up, const float * b_down, float * output, int D, int Hff)
fused_mlp_swiglu_decode_tiled
Forward
void fused_mlp_swiglu_decode_tiled(const float * x, const float * W_gate, const float * W_up, const float * W_down, const float * b_gate, const float * b_up, const float * b_down, float * output, int D, int Hff)
fused_mlp_swiglu_decode_v2
Forward
void fused_mlp_swiglu_decode_v2(const float * x, const float * W_gate, const float * W_up, const float * W_down, const float * b_gate, const float * b_up, const float * b_down, float * output, int D, int Hff)
fused_mlp_swiglu_prefill
Forward
void fused_mlp_swiglu_prefill(const float * x, const float * W_gate, const float * W_up, const float * W_down, float * output, int seq_len, int hidden, int intermediate, float * scratch)
Fused MLP (Gate + Up + SwiGLU + Down) for prefill.
fused_mlp_swiglu_prefill_bias
Forward
void fused_mlp_swiglu_prefill_bias(const float * x, const float * W_gate, const float * W_up, const float * W_down, const float * B_gate, const float * B_up, const float * B_down, float * output, int seq_len, int hidden, int intermediate, float * scratch)
Fused MLP for prefill with proper tiling.
fused_mlp_swiglu_prefill_w1w2_quant
Forward
void fused_mlp_swiglu_prefill_w1w2_quant(const float * x, const void * W1, const float * B1, CKDataType w1_dt, const void * W2, const float * B2, CKDataType w2_dt, float * output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void * scratch)
Quantized fused MLP for prefill (W1=gate+up, W2=down)
fused_mlp_swiglu_prefill_w1w2_quant_scratch_size
Forward
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.
fused_mlp_swiglu_scratch_size
Forward
size_t fused_mlp_swiglu_scratch_size(int intermediate)
Get scratch size for fused MLP.
layer_fused_attn_mlp_qkv_q4k
Forward
void layer_fused_attn_mlp_qkv_q4k(const float * q, const float * k_cache, const float * v_cache, int seq_len, float attn_scale, const void * wo, const float * rms_weight_mlp, const void * w_gate, const void * w_up, const void * w_down, const float * rms_weight_attn, const void * wq_next, const void * wk_next, const void * wv_next, const float * residual_in, int embed_dim, int intermediate_dim, int num_heads, int num_kv_heads, int head_dim, float eps, float * q_next, float * k_next, float * v_next, float * hidden_out)
mega_fused_outproj_mlp_prefill
Forward
void mega_fused_outproj_mlp_prefill(float * output, const float * attn_out, const float * residual, const float * ln2_gamma, const void * wo, const float * bo, CKDataType wo_dt, const void * w1, const float * b1, CKDataType w1_dt, const void * w2, const float * b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void * scratch)
Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.
mega_fused_outproj_mlp_prefill_scratch_size
Forward
size_t mega_fused_outproj_mlp_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
Get scratch buffer size for mega_fused_outproj_mlp_prefill.
mlp_fused_fp32_v2
Forward
void mlp_fused_fp32_v2(const float * hidden_in, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, int embed_dim, int intermediate_dim, float * hidden_out)
mlp_fused_fp32_v3
Forward
void mlp_fused_fp32_v3(const float * hidden_in, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, int embed_dim, int intermediate_dim, float * hidden_out)
mlp_parallel
Forward
void mlp_parallel(const void * ln2_q8, const void * W_gate, const void * W_up, const void * W_down, float * gate_buf, float * up_buf, float * swiglu_buf, void * down_q8, float * mlp_out, int intermediate, int embed_dim, int num_threads)
Parallel MLP (gate/up + SwiGLU + down projection).
mlp_q8_0_dtype_supported
Forward
int mlp_q8_0_dtype_supported(CKDataType dt)
mlp_q8_k_dtype_supported
Forward
int mlp_q8_k_dtype_supported(CKDataType dt)
mlp_separate_fp32
Forward
void mlp_separate_fp32(const float * hidden_in, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, float * normed_buf, float * gate_buf, float * up_buf, int embed_dim, int intermediate_dim, float * hidden_out)
mlp_token_parallel
Forward
void mlp_token_parallel(const float * input, const float * W_fc1, const float * b_fc1, const float * W_fc2, const float * b_fc2, float * fc1_output, float * output, int T, int aligned_dim, int num_threads)
mlp_token_parallel_bf16
Forward
void mlp_token_parallel_bf16(const uint16_t * input, const uint16_t * W_fc1, const uint16_t * b_fc1, const uint16_t * W_fc2, const uint16_t * b_fc2, float * fc1_output, float * output, int T, int aligned_dim, int num_threads, float * scratch_bias1_f, float * scratch_bias2_f, uint16_t * scratch_fc1_bf16)
Optimized MLP Forward (BF16 weights, FP32 activations)
mlp_token_parallel_bf16_fp32act
Forward
void mlp_token_parallel_bf16_fp32act(const uint16_t * input, const uint16_t * W_fc1, const uint16_t * b_fc1, const uint16_t * W_fc2, const uint16_t * b_fc2, float * fc1_output, float * output, int T, int aligned_dim, int num_threads, float * scratch_input_f, float * scratch_bias1_f, float * scratch_bias2_f, uint16_t * scratch_fc1_bf16)
Alternative: Fully FP32 activations throughout
mlp_token_parallel_exact
Forward
void mlp_token_parallel_exact(const float * input, const float * W_fc1, const float * b_fc1, const float * W_fc2, const float * b_fc2, float * fc1_output, float * output, int T, int aligned_dim, int num_threads)
Sigmoid Activation (5)
sigmoid_backward
Backward
void sigmoid_backward(const float * input, const float * d_output, float * d_input, size_t n)
Backward pass / gradient computation
sigmoid_backward_bf16
Backward
void sigmoid_backward_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, size_t n, float * scratch_input, float * scratch_d_output, float * scratch_d_input)
Backward pass / gradient computation
sigmoid_forward
Forward
void sigmoid_forward(const float * input, float * output, size_t n)
Forward pass computation
sigmoid_forward_bf16
Forward
void sigmoid_forward_bf16(const uint16_t * input, uint16_t * output, size_t n, float * scratch_input, float * scratch_output)
Forward pass computation
sigmoid_scalar
Forward
float sigmoid_scalar(float x)
SwiGLU Activation (7)
ck_test_swiglu
Forward
void ck_test_swiglu(const float * gate_up, float * output, int n_tokens, int intermediate_dim)
SwiGLU activation.
swiglu_backward
Backward
void swiglu_backward(const float * input, const float * d_output, float * d_input, int tokens, int dim)
SwiGLU backward pass. Tests: test_swiglu.py::TestSwiGLUBackward::test_backward_tokens, test_swiglu.py::TestSwiGLUBackward::test_backward_single, test_parity.py::test_swiglu_backward_parity
swiglu_backward_bf16
Backward
void swiglu_backward_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, int tokens, int dim)
Backward pass / gradient computation
swiglu_backward_exact
Backward
void swiglu_backward_exact(const float * input, const float * d_output, float * d_input, int tokens, int dim)
SwiGLU backward pass (exact version using stdlib sigmoid). Tests: test_swiglu.py::TestSwiGLUBackward::test_exact_vs_fast, test_swiglu.py::TestSwiGLUBackward::test_exact_single
swiglu_forward
Forward
void swiglu_forward(const float * input, float * output, int tokens, int dim)
SwiGLU forward pass. Tests: test_swiglu.py::TestSwiGLUForward::test_forward_tokens, test_swiglu.py::TestSwiGLUForward::test_forward_single, test_mlp.py::TestMLPForward::test_swiglu_mlp, test_fused_swiglu_decode.py::TestFusedSwiGLUDecode::test_fused_swiglu_decode, test_parity.py::test_swiglu_parity
swiglu_forward_bf16
Forward
void swiglu_forward_bf16(const uint16_t * input, uint16_t * output, int tokens, int dim)
Forward pass computation
swiglu_forward_exact
Forward
void swiglu_forward_exact(const float * input, float * output, int tokens, int dim)
SwiGLU forward pass (exact version using stdlib sigmoid). Tests: test_swiglu.py::TestSwiGLUForward::test_exact_vs_fast, test_swiglu.py::TestSwiGLUForward::test_exact_single
RoPE (Rotary Position Embedding) (22)
apply_rope
Forward
void apply_rope(float * x, int seq_len, int head_dim)
apply_rope_inline
Forward
void apply_rope_inline(float * q, float * k, const float * rope_cos, const float * rope_sin, int pos, int H, int KV, int AD)
ck_model_precompute_rope
Forward
void ck_model_precompute_rope(void * model)
Precompute RoPE cos/sin caches. Call once after allocation, before inference.
ck_test_rope
Forward
void ck_test_rope(float * q, float * k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE (Rotary Position Embedding)
ck_test_rope_interleaved
Forward
void ck_test_rope_interleaved(float * q, float * k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
RoPE with interleaved format (for llama.cpp compatibility)
fused_rope_inplace
Forward
void fused_rope_inplace(float * q, float * k, const float * rope_cos, const float * rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int max_seq)
Fused RoPE application (in-place on pre-allocated buffers)
mega_fuse_rope_inplace_avx
Forward
void mega_fuse_rope_inplace_avx(float * q, float * k, const float * rope_cos, const float * rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim)
Apply RoPE to Q and K (in-place, from L1)
model_precompute_rope
Forward
void model_precompute_rope(MODELModel * model)
qwen2_0_5b_decode_precompute_rope
Forward
void qwen2_0_5b_decode_precompute_rope(QWEN2_0_5B_DECODEModel * model)
rope_apply_head
Forward
void rope_apply_head(float * x, const float * cos_cache, const float * sin_cache, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
rope_backward
Backward
void rope_backward(const float * d_out, float * d_x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
RoPE backward (inverse rotation). Tests: test_rope.py::TestRoPEBackward::test_rope_backward, test_rope.py::TestRoPEBackward::test_rope_backward_vs_separate
rope_backward_bf16
Backward
void rope_backward_bf16(const uint16_t * d_out, uint16_t * d_x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float * scratch_d_out, float * scratch_d_x)
Backward pass / gradient computation
rope_backward_inplace
Backward
void rope_backward_inplace(float * d_x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
RoPE backward in-place (overwrite with inverse rotation). Tests: test_rope.py::TestRoPEBackward::test_rope_backward_inplace
rope_backward_qk
Backward
void rope_backward_qk(const float * d_q_out, const float * d_k_out, float * d_q, float * d_k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
RoPE backward for both dQ and dK. Tests: test_rope.py::TestRoPEBackward::test_rope_backward_qk
rope_backward_qk_bf16
Backward
void rope_backward_qk_bf16(const uint16_t * d_q_out, const uint16_t * d_k_out, uint16_t * d_q, uint16_t * d_k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float * scratch_dq_out, float * scratch_dq, float * scratch_dk_out, float * scratch_dk)
Backward pass / gradient computation
rope_forward
Forward
void rope_forward(float * x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
RoPE forward (head-major layout, in-place). Tests: test_rope.py::TestRoPEForward::test_rope_forward, test_rope.py::TestRoPEForward::test_rope_vs_separate, test_parity.py::test_rope_parity
rope_forward_bf16
Forward
void rope_forward_bf16(uint16_t * x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float * scratch)
Forward pass computation
rope_forward_qk
Forward
void rope_forward_qk(float * q, float * k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
RoPE forward for both Q and K (common inference pattern). Tests: test_rope.py::TestRoPEForward::test_rope_forward_qk, test_fused_attention_decode.py::TestFusedAttentionDecode::test_qk_rope, test_parity.py::test_rope_qk_parity
rope_forward_qk_bf16
Forward
void rope_forward_qk_bf16(uint16_t * q, uint16_t * k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float * scratch_q, float * scratch_k)
Forward pass computation
rope_forward_qk_strided
Forward
void rope_forward_qk_strided(float * q, float * k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
RoPE forward for both Q and K with custom strides (KV cache layouts). Tests: test_rope.py::TestRoPEForward::test_rope_forward_qk_strided, test_kv_cache_attention.py::TestKVCacheAttention::test_qk_rope_strided
rope_forward_strided
Forward
void rope_forward_strided(float * x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)
RoPE forward with custom head stride (for KV cache layouts). Tests: test_rope.py::TestRoPEForward::test_rope_strided, test_kv_cache_attention.py::TestKVCacheAttention::test_rope_decode
rope_precompute_cache
void rope_precompute_cache(float * cos_cache, float * sin_cache, int max_seq_len, int head_dim, float base)
Precompute RoPE cos/sin cache. Tests: test_rope.py::TestRoPECache::test_cache_computation, test_rope.py::TestRoPECache::test_cache_values
Fully Connected Layers (2)
fc1_backward_kernel
Backward
void fc1_backward_kernel(const float * d_output, const float * fc1_input, const float * W_fc1, float * d_input, float * d_W_fc1, float * d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
Backward pass / gradient computation
fc2_backward_kernel
Backward
void fc2_backward_kernel(const float * d_output, const float * fc2_input, const float * W_fc2, float * d_input, float * d_W_fc2, float * d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
Backward pass / gradient computation
Other Functions (877)
_Static_assert
Forward
_Static_assert(sizeof(MagicHeader), "MagicHeader must be 64 bytes")
__attribute__
Forward
struct __attribute__((packed))
adamw_update_bf16
Forward
void adamw_update_bf16(const uint16_t * grad, uint16_t * weight, float * m, float * v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state)
adamw_update_f32
Forward
void adamw_update_f32(const float * grad, float * weight, float * m, float * v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)
AdamW optimizer update (fp32 version)
add_backward_bf16
Backward
void add_backward_bf16(const uint16_t * d_y, uint16_t * d_a, uint16_t * d_b, size_t n)
Backward pass / gradient computation
add_bias_tile
Forward
void add_bias_tile(float * out, const float * bias, int tile_m, int out_dim)
add_forward_2d_bf16
Forward
void add_forward_2d_bf16(const uint16_t * a, const uint16_t * b, uint16_t * y, int tokens, int dim, int aligned_dim)
Forward pass computation
add_forward_bf16
Forward
void add_forward_bf16(const uint16_t * a, const uint16_t * b, uint16_t * y, size_t n)
Forward pass computation
add_forward_f32
Forward
void add_forward_f32(const float * a, const float * b, float * y, size_t n)
Element-wise add: y = a + b. Tests: test_add.py::TestAddForward::test_add_forward_f32, test_add.py::TestAddForward::test_add_inplace_f32, test_multi_layer_parity.py::TestMultiLayerParity::test_residual_add
add_inplace_bf16
Forward
void add_inplace_bf16(uint16_t * a, const uint16_t * b, size_t n)
add_inplace_f32
Forward
void add_inplace_f32(float * a, const float * b, size_t n)
add_scaled_forward_bf16
Forward
void add_scaled_forward_bf16(const uint16_t * a, const uint16_t * b, uint16_t * y, float alpha, size_t n)
Forward pass computation
add_scaled_inplace_bf16
Forward
void add_scaled_inplace_bf16(uint16_t * a, const uint16_t * b, float alpha, size_t n)
align_up
Forward
size_t align_up(size_t n, size_t align)
align_up_bytes
Forward
size_t align_up_bytes(size_t n, size_t align)
align_up_elems
Forward
size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
align_up_size
Forward
size_t align_up_size(size_t value, size_t align)
amx_available
Forward
bool amx_available(void)
apply_bpe_merges
Forward
int apply_bpe_merges(CKTrueBPE * bpe, CKBPETokenList * list)
apply_chat_template
Forward
char * apply_chat_template(const ChatTemplate * tmpl, const char * system, const char * user)
arena_for_role
Forward
CKMemArenaKind arena_for_role(CKBufferRole role)
argmax_f32
Forward
int argmax_f32(const float * scores, int n)
Find index of maximum value.
axpy_2d_f32
Forward
void axpy_2d_f32(float * Y, const float * X, float alpha, int num_tokens, int dim, int y_stride, int x_stride)
Batched AXPY for 2D tensors: Y[t,:] += alpha * X[t,:].
axpy_f32
Forward
void axpy_f32(float * y, const float * x, float alpha, int n)
In-place AXPY: y += alpha * x.
axpy_zero_f32
Forward
void axpy_zero_f32(float * y, const float * x, float alpha, int n)
Zero output then accumulate: y = 0; y += alpha * x.
barrier_init
Forward
void barrier_init(ck_barrier_t * b, int n_threads)
barrier_wait
Forward
void barrier_wait(ck_barrier_t * b)
Spin-wait barrier. All threads must call this. Uses phase counter to allow re-use without reset.
bf16_tensor_to_float
Forward
void bf16_tensor_to_float(const uint16_t * src, float * dst, size_t count)
bf16_to_float
Forward
float bf16_to_float(uint16_t v)
buffer_bytes
Forward
size_t buffer_bytes(const CKIRV2Buffer * buf, const CKModelConfig * cfg, const CKV2AlignInfo * align)
buffer_enabled
Forward
int buffer_enabled(const CKIRV2Graph * graph, const CKIRV2Buffer * buf, int training_enabled)
build_plan
Forward
int build_plan(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes, int training_enabled, int tokens_override)
bump_bytes
Forward
size_t bump_bytes(size_t * off, size_t bytes, size_t align)
byte_to_gpt2
Forward
int byte_to_gpt2(unsigned char byte, char * out)
cache_compare
int cache_compare(const void * a, const void * b)
ck_add_inplace
Forward
void ck_add_inplace(float * dst, const float * src, int tokens, int aligned_embed_dim)
ck_buffer_should_alloc
Forward
int ck_buffer_should_alloc(const CKBufferSpec * spec)
ck_buffer_uses_weight_dtype
Forward
int ck_buffer_uses_weight_dtype(const CKBufferSpec * spec)
ck_build_decoder_backward_ir
Backward
int ck_build_decoder_backward_ir(const CKIRGraph * forward, CKIRGraph * backward)
Build a naive backward IR graph from a forward decoder IR.
ck_build_decoder_ir
Forward
int ck_build_decoder_ir(const CKModelConfig * cfg, CKIRGraph * graph)
Build a simple decoder-only IR graph for the given config.
ck_codegen_c_skeleton
Forward
void ck_codegen_c_skeleton(const CKIRGraph * forward, const CKIRGraph * backward, FILE * out)
Emit a C skeleton for forward + backward execution based on the IR.
ck_codegen_emit_runtime
Forward
int ck_codegen_emit_runtime(const CKIRGraph * forward, const char * path, CKEmitMode mode)
Emit a C runtime file that stitches kernels for the given forward IR.
ck_codegen_v2_dtype_name
Forward
const char * ck_codegen_v2_dtype_name(CKDataType dtype)
ck_codegen_v2_emit_dispatch
Forward
void ck_codegen_v2_emit_dispatch(FILE * out, const CKIRV2Graph * graph)
ck_codegen_v2_emit_preamble
Forward
int ck_codegen_v2_emit_preamble(FILE * out)
ck_codegen_v2_emit_runtime
Forward
int ck_codegen_v2_emit_runtime(const CKIRV2Graph * graph, const char * path, CKEmitMode mode)
Emit a C runtime file from a CKIRV2Graph.
ck_codegen_v2_emit_schedule
Forward
void ck_codegen_v2_emit_schedule(FILE * out, const CKIRV2Graph * graph, const char * prefill_runtime, const char * decode_runtime, const char * backward_runtime)
ck_codegen_v2_emit_sections
Forward
void ck_codegen_v2_emit_sections(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * prefill_plan, const CKMemPlan * decode_plan, const CKMemPlan * backward_plan)
ck_codegen_v2_emit_struct
Forward
void ck_codegen_v2_emit_struct(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan, const char * tag)
ck_debug_check_buffer
Forward
void ck_debug_check_buffer(const char * stage, const float * buf, int size)
ck_debug_check_q4k_weights
Forward
void ck_debug_check_q4k_weights(const char * stage, const void * q4_buf, int num_blocks)
ck_debug_check_q8k
Forward
void ck_debug_check_q8k(const char * stage, const void * q8_buf, int num_blocks)
ck_dot_f32
Forward
float ck_dot_f32(const float * a, const float * b, int len)
ck_dtype_block_bytes
Forward
size_t ck_dtype_block_bytes(CKDataType dt)
Get bytes per block for quantized types.
ck_dtype_block_size
Forward
size_t ck_dtype_block_size(CKDataType dt)
Get the number of elements per quantization block.
ck_dtype_bytes
Forward
size_t ck_dtype_bytes(CKDataType dt)
Get bytes per element for non-quantized types.
ck_dtype_is_quantized
Forward
int ck_dtype_is_quantized(CKDataType dt)
Check if a data type is block-quantized (GGML-style)
ck_dtype_row_bytes
Forward
size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
ck_dtype_supported
Forward
int ck_dtype_supported(CKDataTypeMask mask, CKDataType dt)
ck_expf
Forward
float ck_expf(float x)
ck_fast_expf
Forward
float ck_fast_expf(float x)
ck_find_buffer_spec
Forward
const CKBufferSpec * ck_find_buffer_spec(const char * name)
ck_find_kernel_spec
Forward
const CKKernelSpec * ck_find_kernel_spec(const char * name)
ck_first_layer_buffer_name
Forward
const char * ck_first_layer_buffer_name(void)
ck_flash_attn_choose_tile_k
Forward
int ck_flash_attn_choose_tile_k(int D_h)
ck_flash_attn_fast_exp_kind
Forward
int ck_flash_attn_fast_exp_kind(void)
ck_flash_attn_tile_k
Forward
int ck_flash_attn_tile_k(int D_h)
ck_fma_f32_to_f16
Forward
void ck_fma_f32_to_f16(const float * a, const float * b, const float * c, uint16_t * dst, int n)
FMA in FP32, store result as FP16: dst = a * b + c.
ck_fp16_to_fp32
Forward
float ck_fp16_to_fp32(ck_half h)
ck_fp16_to_fp32_2d
Forward
void ck_fp16_to_fp32_2d(const uint16_t * src, float * dst, int rows, int cols, int src_stride, int dst_stride)
Convert 2D FP16 matrix to FP32 with strided access.
ck_fp16_to_fp32_row
Forward
void ck_fp16_to_fp32_row(const uint16_t * src, float * dst, int n)
Convert FP16 row to FP32 (auto-select best implementation)
ck_fp16_to_fp32_scalar
Forward
float ck_fp16_to_fp32_scalar(uint16_t h)
ck_fp16_to_fp32_soft
Forward
float ck_fp16_to_fp32_soft(ck_half h)
Convert FP16 (ck_half) to FP32 — software implementation.
ck_fp32_to_fp16
Forward
ck_half ck_fp32_to_fp16(float f)
ck_fp32_to_fp16_2d
Forward
void ck_fp32_to_fp16_2d(const float * src, uint16_t * dst, int rows, int cols, int src_stride, int dst_stride)
Convert 2D FP32 matrix to FP16 with strided access.
ck_fp32_to_fp16_inplace
Forward
void ck_fp32_to_fp16_inplace(float * data, void * scratch, int n)
Convert FP32 to FP16 in-place using scratch buffer.
ck_fp32_to_fp16_row
Forward
void ck_fp32_to_fp16_row(const float * src, uint16_t * dst, int n)
Convert FP32 row to FP16 (auto-select best implementation)
ck_fp32_to_fp16_scalar
Forward
uint16_t ck_fp32_to_fp16_scalar(float f)
ck_fp32_to_fp16_soft
Forward
ck_half ck_fp32_to_fp16_soft(float f)
Convert FP32 to FP16 (ck_half) — software implementation.
ck_get_block_q4_k_size
Forward
int ck_get_block_q4_k_size(void)
Get Q4_K block size in bytes.
ck_get_block_q5_1_size
Forward
int ck_get_block_q5_1_size(void)
Get Q5_1 block size in bytes (24 bytes per 32 weights)
ck_get_block_q5_k_size
Forward
int ck_get_block_q5_k_size(void)
Get Q5_K block size in bytes (176 bytes per 256 weights)
ck_get_block_q6_k_size
Forward
int ck_get_block_q6_k_size(void)
Get Q6_K block size in bytes.
ck_get_block_q8_k_size
Forward
int ck_get_block_q8_k_size(void)
Get Q8_K block size in bytes.
ck_get_capabilities
Forward
ck_capability_t ck_get_capabilities(void)
Get current platform capabilities.
ck_get_num_threads
Forward
int ck_get_num_threads(void)
ck_get_physical_cores
Forward
int ck_get_physical_cores(void)
ck_get_qk5_1
Forward
int ck_get_qk5_1(void)
Get QK5_1 (elements per Q5_1 block)
ck_get_qk_k
Forward
int ck_get_qk_k(void)
Get QK_K (elements per super-block)
ck_get_threadpool
Forward
ck_threadpool_t * ck_get_threadpool(void)
Get the global thread pool handle for dispatch. Convenience wrapper — initializes on first call.
ck_huge_alloc
Forward
void * ck_huge_alloc(size_t bytes)
Allocate a large, contiguous memory region for model weights/activations.
ck_huge_free
Forward
void ck_huge_free(void * ptr, size_t bytes)
Free memory allocated by ck_huge_alloc.
ck_ir_dump
Forward
void ck_ir_dump(const CKIRGraph * graph, FILE * out)
Dump a human-readable view of the IR to the given stream.
ck_ir_free
Forward
void ck_ir_free(CKIRGraph * graph)
Free any heap-allocated memory owned by the graph.
ck_ir_parse_json
Forward
int ck_ir_parse_json(const char * path, CKIRGraph * graph)
Parse a JSON IR map file (as produced by ck_ir_serialize_json) back into a CKIRGraph. This enables a two-stage pipeline: serialize the IR once, then re-parse and lower it in a separate build step.
ck_ir_serialize_json
Forward
int ck_ir_serialize_json(const CKIRGraph * graph, const char * path)
Serialize a CKIRGraph to a simple JSON IR map file.
ck_ir_v2_align_up_bytes
Forward
size_t ck_ir_v2_align_up_bytes(size_t n, size_t align)
ck_ir_v2_align_up_elems
Forward
size_t ck_ir_v2_align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)
ck_ir_v2_apply_meta
Forward
int ck_ir_v2_apply_meta(const char * path, CKIRV2Graph * graph)
ck_ir_v2_apply_weight_dtypes
Forward
int ck_ir_v2_apply_weight_dtypes(const char * json, const char * end, CKIRV2Graph * graph)
ck_ir_v2_build_decoder
Forward
int ck_ir_v2_build_decoder(const CKModelConfig * cfg, CKIRV2Graph * graph)
ck_ir_v2_build_decoder_backward
Backward
int ck_ir_v2_build_decoder_backward(const CKIRV2Graph * forward, CKIRV2Graph * backward)
Backward pass / gradient computation
ck_ir_v2_copy_buffer_spec
Forward
int ck_ir_v2_copy_buffer_spec(const CKBufferSpec * spec, CKIRV2Buffer * out)
ck_ir_v2_copy_shape
Forward
void ck_ir_v2_copy_shape(CKDimToken * dst, const CKDimToken * src)
ck_ir_v2_dim_kind_from_name
Forward
CKDimKind ck_ir_v2_dim_kind_from_name(const char * name)
ck_ir_v2_dim_name
Forward
const char * ck_ir_v2_dim_name(CKDimKind dim)
ck_ir_v2_dtype_name
Forward
const char * ck_ir_v2_dtype_name(CKDataType dtype)
ck_ir_v2_emit_dimensions
Forward
void ck_ir_v2_emit_dimensions(FILE * out, const CKModelConfig * cfg, const CKIRV2AlignInfo * align, int tokens_override)
ck_ir_v2_emit_memory_plan
Forward
void ck_ir_v2_emit_memory_plan(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan)
ck_ir_v2_emit_resolved_shape
Forward
void ck_ir_v2_emit_resolved_shape(FILE * out, const CKModelConfig * cfg, const CKIRV2AlignInfo * align, const CKDimToken * shape, int tokens_override)
ck_ir_v2_emit_shape
Forward
int ck_ir_v2_emit_shape(FILE * out, const CKDimToken * shape)
ck_ir_v2_find_array_end
Forward
const char * ck_ir_v2_find_array_end(const char * open, const char * end)
ck_ir_v2_find_buffer_index
Forward
int ck_ir_v2_find_buffer_index(const CKIRV2Graph * graph, const char * name)
ck_ir_v2_find_buffer_spec
Forward
const CKBufferSpec * ck_ir_v2_find_buffer_spec(const char * name)
ck_ir_v2_find_kernel_spec
Forward
const CKKernelSpec * ck_ir_v2_find_kernel_spec(const char * name)
ck_ir_v2_find_key
Forward
const char * ck_ir_v2_find_key(const char * json, const char * key, const char * end)
ck_ir_v2_free
Forward
void ck_ir_v2_free(CKIRV2Graph * graph)
ck_ir_v2_free_buffer
Forward
void ck_ir_v2_free_buffer(CKIRV2Buffer * buf)
ck_ir_v2_free_node
Forward
void ck_ir_v2_free_node(CKIRV2Node * node)
ck_ir_v2_lower_copy_buffers
Forward
int ck_ir_v2_lower_copy_buffers(const CKIRV2Graph * input, CKIRV2Graph * output)
ck_ir_v2_lower_copy_nodes
Forward
int ck_ir_v2_lower_copy_nodes(const CKIRV2Graph * input, CKIRV2LowerMode mode, CKIRV2Graph * output)
ck_ir_v2_lower_emit_json
Forward
int ck_ir_v2_lower_emit_json(const CKIRV2Graph * input, CKIRV2LowerMode mode, const char * path)
ck_ir_v2_lower_graph
Forward
int ck_ir_v2_lower_graph(const CKIRV2Graph * input, CKIRV2LowerMode mode, CKIRV2Graph * output, CKMemPlan * plan)
ck_ir_v2_lower_mode_from_string
Forward
int ck_ir_v2_lower_mode_from_string(const char * name, CKIRV2LowerMode * out_mode)
ck_ir_v2_lower_mode_name
Forward
const char * ck_ir_v2_lower_mode_name(CKIRV2LowerMode mode)
ck_ir_v2_lower_node_enabled
Forward
int ck_ir_v2_lower_node_enabled(const CKIRV2Node * node, CKIRV2LowerMode mode)
ck_ir_v2_lower_strdup
Forward
char * ck_ir_v2_lower_strdup(const char * s)
ck_ir_v2_mem_arena_name
Forward
const char * ck_ir_v2_mem_arena_name(CKMemArenaKind arena)
ck_ir_v2_next_object
Forward
const char * ck_ir_v2_next_object(const char * cur, const char * end, const char ** obj_start, const char ** obj_end)
ck_ir_v2_parse_bindings
Forward
int ck_ir_v2_parse_bindings(const char * obj_start, const char * obj_end, CKIRV2Graph * graph, CKIRV2Node * node)
ck_ir_v2_parse_bool
Forward
int ck_ir_v2_parse_bool(const char * json, const char * key, const char * end, int * out_val)
ck_ir_v2_parse_buffers
Forward
int ck_ir_v2_parse_buffers(const char * json, const char * end, CKIRV2Graph * graph)
ck_ir_v2_parse_dim_kind
Forward
CKDimKind ck_ir_v2_parse_dim_kind(const char * obj_start, const char * obj_end)
ck_ir_v2_parse_dtype
Forward
CKDataType ck_ir_v2_parse_dtype(const char * s)
ck_ir_v2_parse_float
Forward
int ck_ir_v2_parse_float(const char * json, const char * key, const char * end, float * out_val)
ck_ir_v2_parse_int
Forward
int ck_ir_v2_parse_int(const char * json, const char * key, const char * end, int * out_val)
ck_ir_v2_parse_json
Forward
int ck_ir_v2_parse_json(const char * path, CKIRV2Graph * graph)
ck_ir_v2_parse_nodes
Forward
int ck_ir_v2_parse_nodes(const char * json, const char * end, CKIRV2Graph * graph)
ck_ir_v2_parse_role
Forward
CKBufferRole ck_ir_v2_parse_role(const char * s)
ck_ir_v2_parse_scope
Forward
CKBufferScope ck_ir_v2_parse_scope(const char * s)
ck_ir_v2_parse_shape
Forward
int ck_ir_v2_parse_shape(const char * obj_start, const char * obj_end, CKDimToken * shape_out)
ck_ir_v2_parse_string
Forward
int ck_ir_v2_parse_string(const char * start, const char * end, char ** out_str)
ck_ir_v2_parse_string_field
Forward
int ck_ir_v2_parse_string_field(const char * json, const char * key, const char * end, char ** out_str)
ck_ir_v2_resolve_align
Forward
void ck_ir_v2_resolve_align(const CKModelConfig * cfg, size_t alignment_bytes, CKIRV2AlignInfo * align)
ck_ir_v2_resolve_dim_value
Forward
size_t ck_ir_v2_resolve_dim_value(const CKModelConfig * cfg, const CKIRV2AlignInfo * align, CKDimKind dim, int tokens_override)
ck_ir_v2_role_name
Forward
const char * ck_ir_v2_role_name(CKBufferRole role)
ck_ir_v2_scope_name
Forward
const char * ck_ir_v2_scope_name(CKBufferScope scope)
ck_ir_v2_select_kernel
Forward
const char * ck_ir_v2_select_kernel(const CKKernelSpec * spec, CKDataType dtype, int backward)
ck_ir_v2_serialize_json
Forward
int ck_ir_v2_serialize_json(const CKIRV2Graph * graph, const char * path)
ck_ir_v2_serialize_json_internal
Forward
int ck_ir_v2_serialize_json_internal(const CKIRV2Graph * graph, const CKMemPlan * plan, const char * mode, int tokens_override, int base_context_window, const char * path)
ck_ir_v2_serialize_json_with_plan
Forward
int ck_ir_v2_serialize_json_with_plan(const CKIRV2Graph * graph, const CKMemPlan * plan, const char * mode, int tokens_override, int base_context_window, const char * path)
ck_ir_v2_skip_string
Forward
const char * ck_ir_v2_skip_string(const char * cur, const char * end)
ck_ir_v2_skip_ws
Forward
const char * ck_ir_v2_skip_ws(const char * cur, const char * end)
ck_ir_v2_strdup
Forward
char * ck_ir_v2_strdup(const char * s)
ck_ir_validate_supported
Forward
int ck_ir_validate_supported(const CKIRGraph * graph)
ck_layer_debug_enabled
Forward
int ck_layer_debug_enabled(void)
ck_load_weights_manifest_v4
Forward
int ck_load_weights_manifest_v4(void * base, const char * weights_path, const char * manifest_path)
Load BUMPWGT4 weights into a v4 model buffer using a manifest map.
ck_mem_plan_build_inference
Forward
int ck_mem_plan_build_inference(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes)
ck_mem_plan_build_inference_with_tokens
Forward
int ck_mem_plan_build_inference_with_tokens(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes, int tokens_override)
ck_mem_plan_build_training
Forward
int ck_mem_plan_build_training(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes)
ck_mem_plan_build_training_with_tokens
Forward
int ck_mem_plan_build_training_with_tokens(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes, int tokens_override)
ck_mem_plan_free
Forward
void ck_mem_plan_free(CKMemPlan * plan)
ck_memory_allocate
Forward
int ck_memory_allocate(CKModel * model, int use_hugepages)
Allocate the planned memory.
ck_memory_free
Forward
void ck_memory_free(CKModel * model)
Free the model memory.
ck_memory_plan
Forward
size_t ck_memory_plan(const CKSectionConfig * sections, int num_sections, int mode, uint32_t fusion_flags, CKModel * out_model)
Plan memory layout for a model.
ck_metrics_cleanup
Forward
void ck_metrics_cleanup(void)
Cleanup and free resources
ck_metrics_create_context
Forward
CKMetricsContext * ck_metrics_create_context(void)
ck_metrics_ctx_end
Forward
void ck_metrics_ctx_end(CKMetricsContext * ctx, const char * status)
ck_metrics_ctx_init
Forward
bool ck_metrics_ctx_init(CKMetricsContext * ctx, const char * run_id, const char * endpoint, CKMetricsMode mode)
ck_metrics_ctx_log_f
Forward
void ck_metrics_ctx_log_f(CKMetricsContext * ctx, const char * name, double value)
ck_metrics_ctx_log_i
Forward
void ck_metrics_ctx_log_i(CKMetricsContext * ctx, const char * name, int64_t value)
ck_metrics_ctx_step
Forward
void ck_metrics_ctx_step(CKMetricsContext * ctx, int64_t step)
ck_metrics_destroy_context
Forward
void ck_metrics_destroy_context(CKMetricsContext * ctx)
ck_metrics_end
Forward
void ck_metrics_end(const char * status)
End the training run
ck_metrics_generate_run_id
Forward
void ck_metrics_generate_run_id(char * buffer, size_t size)
Generate a unique run ID based on current timestamp
ck_metrics_get_memory_mb
Forward
int64_t ck_metrics_get_memory_mb(void)
Get current memory usage in MB (platform-specific)
ck_metrics_init
Forward
bool ck_metrics_init(const char * run_id, const char * endpoint, CKMetricsMode mode)
Initialize metrics logging
ck_metrics_init_full
Forward
bool ck_metrics_init_full(const char * run_id, const char * endpoint, CKMetricsMode mode, const char * model, const char * dataset, int batch_size, double lr, int max_steps)
Initialize with full configuration
ck_metrics_log_f
Forward
void ck_metrics_log_f(const char * name, double value)
Log a float metric (e.g., loss, learning rate)
ck_metrics_log_i
Forward
void ck_metrics_log_i(const char * name, int64_t value)
Log an integer metric (e.g., step, tokens_per_sec)
ck_metrics_log_s
Forward
void ck_metrics_log_s(const char * name, const char * value)
Log a string metric (e.g., phase, status)
ck_metrics_step
Forward
void ck_metrics_step(int64_t step)
Flush metrics for the current step and advance to the next step. Call this at the end of each training step.
ck_metrics_timestamp
Forward
double ck_metrics_timestamp(void)
Get current timestamp in seconds with microsecond precision
ck_min
Forward
int ck_min(int a, int b)
ck_min_i
Forward
int ck_min_i(int a, int b)
ck_model_allocate
Forward
int ck_model_allocate(CKModel * model, int hugepage_mode)
Allocate the planned memory. hugepage_mode: 0=normal, 1=2MB hugepages, 2=1GB hugepages. Returns 0 on success, -1 on failure.
ck_model_config_from_hf_json
Forward
int ck_model_config_from_hf_json(const char * path, CKModelConfig * cfg)
Parse a HuggingFace-style config.json into CKModelConfig.
ck_model_create
Forward
void * ck_model_create(void)
Create and allocate model memory. Returns opaque model pointer, or NULL on failure.
ck_model_decode
Forward
void ck_model_decode(void * model, const int * token, int token_index)
Decode single token at position token_index. Used for autoregressive generation.
ck_model_forward
Forward
void ck_model_forward(void * model, const int * tokens, int num_tokens)
Forward pass (prefill) - process multiple tokens. Used for initial prompt processing.
ck_model_free
Forward
void ck_model_free(void * model)
Free model memory.
ck_model_get_base
Forward
void * ck_model_get_base(void * model)
Get model base pointer (for weight loading).
ck_model_get_config
Forward
const CKModelConfig * ck_model_get_config(void)
Get model configuration (dimensions, sizes, etc.) This is available before allocation.
ck_model_get_logits
Forward
float * ck_model_get_logits(void * model)
Get pointer to output logits buffer. Size is vocab_size floats.
ck_model_get_total_bytes
Forward
size_t ck_model_get_total_bytes(void * model)
Get total model size in bytes.
ck_model_load_weights
Forward
int ck_model_load_weights(void * model, const char * bump_path)
Load weights from BUMP file into model. Returns 0 on success, -1 on failure.
ck_model_load_weights_flat
Forward
int ck_model_load_weights_flat(TransformerModel * m, const char * path)
Load weights from a single flat binary file into model->memory_base.
ck_model_plan
Forward
size_t ck_model_plan(CKModel * model, const CKSectionConfig * configs, int num_sections, int training_enabled, uint32_t fusion_flags)
Plan memory layout for complete model. Returns total bytes needed.
ck_model_verify_canaries
Forward
int ck_model_verify_canaries(void * model)
Verify memory canaries (debug). Returns number of corrupted canaries (0 = OK).
ck_murmurhash3
Forward
uint32_t ck_murmurhash3(const char * key, uint32_t len, uint32_t seed)
MurmurHash3-32bit hash function (original HPC_Embeddings version).
ck_murmurhash3_128
Forward
void ck_murmurhash3_128(const void * key, size_t len, uint32_t seed, uint64_t * out1, uint64_t * out2)
MurmurHash3-128bit hash function (produces two 64-bit values).
ck_murmurhash3_32
Forward
uint32_t ck_murmurhash3_32(const void * key, size_t len, uint32_t seed)
MurmurHash3-32bit hash function (alternative name).
ck_murmurhash3_str
Forward
uint32_t ck_murmurhash3_str(const char * key, uint32_t seed)
ck_murmurhash3_strn
Forward
uint32_t ck_murmurhash3_strn(const char * str, size_t len, uint32_t seed)
MurmurHash3-32bit for string with length.
ck_nearest_int
Forward
int ck_nearest_int(float fval)
ck_nearest_int_fused
Forward
int ck_nearest_int_fused(float fval)
ck_op_name
Forward
const char * ck_op_name(CKOpType op)
ck_op_supported
Forward
int ck_op_supported(CKOpType op)
ck_parse_env_int
Forward
int ck_parse_env_int(const char * name)
ck_plan_step_enabled
Forward
int ck_plan_step_enabled(const CKPlanStep * step, const CKIRGraph * cfg)
ck_pool_alloc
Forward
void * ck_pool_alloc(CKMemPool * pool, size_t size)
ck_pool_free
Forward
void ck_pool_free(CKMemPool * pool)
ck_pool_init
Forward
void ck_pool_init(CKMemPool * pool)
ck_pool_strdup
Forward
char * ck_pool_strdup(CKMemPool * pool, const char * s, int len)
ck_prefill_forward
Forward
int ck_prefill_forward(const void * weights, const int32_t * tokens, int n_tokens, float * hidden_out, void * kv_cache, int kv_pos)
Forward pass computation
ck_q8_0_outproj_enabled
Forward
int ck_q8_0_outproj_enabled(void)
ck_q8k_activations_enabled
Forward
int ck_q8k_activations_enabled(void)
ck_qkv_project_head_major
Forward
void ck_qkv_project_head_major(const float * input, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
ck_qkv_project_head_major_backward
Backward
void ck_qkv_project_head_major_backward(const float * d_q, const float * d_k, const float * d_v, const float * input, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, float * d_input, float * d_wq, float * d_bq, float * d_wk, float * d_bk, float * d_wv, float * d_bv, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)
Backward pass / gradient computation
ck_qkv_project_head_major_q4_k
Forward
void ck_qkv_project_head_major_q4_k(const float * input, const void * wq, const float * bq, const void * wk, const float * bk, const void * wv, const float * bv, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
ck_qkv_project_head_major_q4_k_q8_k
Forward
void ck_qkv_project_head_major_q4_k_q8_k(const float * input, const void * wq, const float * bq, const void * wk, const float * bk, const void * wv, const float * bv, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
ck_qkv_project_head_major_quant
Forward
void ck_qkv_project_head_major_quant(const float * input, const void * wq, const float * bq, CKDataType wq_dtype, const void * wk, const float * bk, CKDataType wk_dtype, const void * wv, const float * bv, CKDataType wv_dtype, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
ck_qkv_project_head_major_ref
Forward
void ck_qkv_project_head_major_ref(const float * input, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
ck_qkv_project_head_major_token
Forward
void ck_qkv_project_head_major_token(const float * input_row, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, float * q_token, float * k_token, float * v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
ck_qkv_project_head_major_token_q4_k
Forward
void ck_qkv_project_head_major_token_q4_k(const float * input_row, const void * wq, const float * bq, const void * wk, const float * bk, const void * wv, const float * bv, float * q_token, float * k_token, float * v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
ck_qkv_project_head_major_token_q4_k_q8_k
Forward
void ck_qkv_project_head_major_token_q4_k_q8_k(const block_q8_K * input_q8, const void * wq, const float * bq, const void * wk, const float * bk, const void * wv, const float * bv, float * q_token, float * k_token, float * v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
ck_qkv_project_head_major_token_quant
Forward
void ck_qkv_project_head_major_token_quant(const float * input_row, const void * wq, const float * bq, CKDataType wq_dtype, const void * wk, const float * bk, CKDataType wk_dtype, const void * wv, const float * bv, CKDataType wv_dtype, float * q_token, float * k_token, float * v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)
ck_quant_block_size
Forward
size_t ck_quant_block_size(int type)
Get the block size (number of weights per block) for a quant type.
ck_quant_row_size
Forward
size_t ck_quant_row_size(int type, int64_t n_elements)
Calculate total bytes needed for n_elements with given quant type.
ck_quant_type_size
Forward
size_t ck_quant_type_size(int type)
Get the byte size per block for a quant type.
ck_residual_add_backward
Backward
void ck_residual_add_backward(const float * d_out, float * d_a, float * d_b, int tokens, int aligned_embed_dim)
Backward pass / gradient computation
ck_residual_add_token_major
Forward
void ck_residual_add_token_major(const float * a, const float * b, float * out, int tokens, int aligned_embed_dim)
ck_round_nearest
Forward
int ck_round_nearest(float v)
Round to nearest int, half away from zero (matches quantize_row_q8_0)
ck_scale_f32_to_f16
Forward
void ck_scale_f32_to_f16(const float * src, float scale, uint16_t * dst, int n)
Scale FP32 array and store as FP16: dst = scale * src.
ck_section_config_init
Forward
void ck_section_config_init(CKSectionConfig * config, size_t simd_align)
Initialize section config with computed alignments.
ck_section_plan
Forward
size_t ck_section_plan(CKSection * section, const CKSectionConfig * config, int training_enabled, size_t base_offset)
Plan memory layout for a single section. Returns bytes needed for this section.
ck_set_num_threads
Forward
void ck_set_num_threads(int num_threads)
ck_set_strict_parity
Forward
void ck_set_strict_parity(int enabled)
ck_strict_parity_enabled
Forward
int ck_strict_parity_enabled(void)
ck_test_dequant_q4_0
Forward
void ck_test_dequant_q4_0(const void * src, float * dst, int n)
Dequantize Q4_0 data to FP32.
ck_test_dequant_q4_k
Forward
void ck_test_dequant_q4_k(const void * src, float * dst, int n)
Dequantize Q4_K data to FP32.
ck_test_dequant_q5_1
Forward
void ck_test_dequant_q5_1(const void * src, float * dst, int n)
Dequantize Q5_1 data to FP32.
ck_test_dequant_q6_k
Forward
void ck_test_dequant_q6_k(const void * src, float * dst, int n)
Dequantize Q6_K data to FP32.
ck_test_geglu
Forward
void ck_test_geglu(const float * x, float * out, int n_tokens, int dim)
Test GeGLU activation.
ck_test_geglu_backward
Backward
void ck_test_geglu_backward(const float * x, const float * d_out, float * d_x, int n_tokens, int dim)
Test GeGLU backward.
ck_test_gemv_q4_k
Forward
void ck_test_gemv_q4_k(const void * weight_q4k, const float * input_f32, float * output, int cols)
Q4_K GEMV - dot product of quantized weights and FP32 input.
ck_test_gemv_q5_0
Forward
void ck_test_gemv_q5_0(const void * weight_q5_0, const float * input_f32, float * output, int rows, int cols)
Q5_0 GEMV - matrix-vector multiply with Q5_0 weights.
ck_test_gemv_q5_0_q8_0
Forward
void ck_test_gemv_q5_0_q8_0(const void * weight_q5_0, const float * input_f32, float * output, int rows, int cols)
Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
ck_test_gemv_q5_1
Forward
void ck_test_gemv_q5_1(const void * weight_q5_1, const float * input_f32, float * output, int rows, int cols)
Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks)
ck_test_gemv_q5_k
Forward
void ck_test_gemv_q5_k(const void * weight_q5_k, const float * input_f32, float * output, int rows, int cols)
Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks)
ck_test_gemv_q6_k
Forward
void ck_test_gemv_q6_k(const void * weight_q6k, const float * input_f32, float * output, int cols)
Q6_K GEMV.
ck_test_gemv_q8_0
Forward
void ck_test_gemv_q8_0(const void * weight_q8_0, const float * input_f32, float * output, int rows, int cols)
Q8_0 GEMV - matrix-vector multiply with Q8_0 weights.
ck_test_gemv_q8_0_q8_0
Forward
void ck_test_gemv_q8_0_q8_0(const void * weight_q8_0, const float * input_f32, float * output, int rows, int cols)
Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
ck_test_quantize_q8_k
Forward
void ck_test_quantize_q8_k(const float * src, void * dst, int n)
Quantize FP32 to Q8_K (for activations)
ck_test_vec_dot_q5_0_q8_0
Forward
void ck_test_vec_dot_q5_0_q8_0(const void * weight_q5_0, const void * input_q8_0, float * output, int cols)
Direct Q5_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)
ck_test_vec_dot_q8_0_q8_0
Forward
void ck_test_vec_dot_q8_0_q8_0(const void * weight_q8_0, const void * input_q8_0, float * output, int cols)
Direct Q8_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)
ck_threadpool_barrier
Forward
void ck_threadpool_barrier(ck_threadpool_t * pool)
Barrier synchronization within a dispatched work function.
ck_threadpool_create
Forward
ck_threadpool_t * ck_threadpool_create(int n_threads)
Create a thread pool with n_threads total threads. Thread 0 is the calling (main) thread; n_threads-1 workers are spawned.
ck_threadpool_destroy
Forward
void ck_threadpool_destroy(ck_threadpool_t * pool)
Destroy the thread pool. Signals all workers to exit and joins them. Safe to call with NULL.
ck_threadpool_dispatch
Forward
void ck_threadpool_dispatch(ck_threadpool_t * pool, ck_work_fn_t fn, void * args)
Dispatch work to all threads and wait for completion.
ck_threadpool_global
Forward
ck_threadpool_t * ck_threadpool_global(void)
Get or create the global thread pool. Thread-safe (uses pthread_once internally). Uses ck_get_num_threads() for auto-detection.
ck_threadpool_global_destroy
Forward
void ck_threadpool_global_destroy(void)
Destroy the global thread pool. Called during engine shutdown.
ck_threadpool_init
Forward
void ck_threadpool_init(void)
Initialize the global thread pool. Called once during engine startup (e.g., from ck_model_init). Uses ck_get_num_threads() for thread count (respects CK_NUM_THREADS env).
ck_threadpool_n_threads
Forward
int ck_threadpool_n_threads(const ck_threadpool_t * pool)
Get total thread count (including main thread)
ck_threadpool_pause
Forward
void ck_threadpool_pause(ck_threadpool_t * pool)
Pause workers — they sleep on condvar (0% CPU). Call between batches or during interactive waiting. Workers wake on next dispatch or resume.
ck_threadpool_resume
Forward
void ck_threadpool_resume(ck_threadpool_t * pool)
Resume workers — transition from sleep to spin-wait. Call before starting a new batch of work.
ck_threadpool_shutdown
Forward
void ck_threadpool_shutdown(void)
Shut down the global thread pool. Called during engine teardown. Workers are joined and freed.
ck_threadpool_thread_id
Forward
int ck_threadpool_thread_id(const ck_threadpool_t * pool)
Get thread index for current thread (0 = main, -1 if not in pool)
ck_tokenizer_add_merge
Forward
int ck_tokenizer_add_merge(CKTokenizer * tok, int32_t left, int32_t right, int32_t merged)
ck_tokenizer_add_special_token
Forward
int ck_tokenizer_add_special_token(CKTokenizer * tok, const char * name, int32_t id)
Add special token (UNK, BOS, EOS, PAD, MASK).
ck_tokenizer_add_token
Forward
int32_t ck_tokenizer_add_token(CKTokenizer * tok, const char * token, int len)
ck_tokenizer_create
Forward
CKTokenizer * ck_tokenizer_create(CKTokenizerType type)
ck_tokenizer_create_bpe
Forward
CKTokenizer * ck_tokenizer_create_bpe(void)
Create tokenizer with default BPE config.
ck_tokenizer_create_spm
Forward
CKTokenizer * ck_tokenizer_create_spm(void)
Create tokenizer with default SPM config.
ck_tokenizer_create_wordpiece
Forward
CKTokenizer * ck_tokenizer_create_wordpiece(void)
Create tokenizer with default WordPiece config.
ck_tokenizer_decode
Forward
int ck_tokenizer_decode(const CKTokenizer * tok, const int32_t * ids, int num_ids, char * text, int max_len)
ck_tokenizer_detect_space_prefix_style
Forward
CKSpacePrefixStyle ck_tokenizer_detect_space_prefix_style(CKTokenizer * tok)
ck_tokenizer_encode
Forward
int ck_tokenizer_encode(const CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids)
ck_tokenizer_encode_spm_impl
Forward
int ck_tokenizer_encode_spm_impl(const CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids)
ck_tokenizer_encode_spm_llama_impl
Forward
int ck_tokenizer_encode_spm_llama_impl(const CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids)
ck_tokenizer_encode_tokens
Forward
int ck_tokenizer_encode_tokens(const CKTokenizer * tok, const char * text, int text_len, const char ** out_tokens, int max_tokens)
Encode and return tokens as array of strings.
ck_tokenizer_encode_with_special
Forward
int ck_tokenizer_encode_with_special(CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids, bool add_special)
Encode with special token handling.
ck_tokenizer_free
Forward
void ck_tokenizer_free(CKTokenizer * tok)
ck_tokenizer_hash
Forward
uint32_t ck_tokenizer_hash(const char * key, size_t len)
ck_tokenizer_hash_str
Forward
uint32_t ck_tokenizer_hash_str(const char * key)
ck_tokenizer_hash_table_clear
Forward
void ck_tokenizer_hash_table_clear(CKTokenizerHashTable * table, bool free_values)
Clear all entries (but keep bucket array).
ck_tokenizer_hash_table_contains
Forward
bool ck_tokenizer_hash_table_contains(CKTokenizerHashTable * table, const char * key)
Check if key exists.
ck_tokenizer_hash_table_count
Forward
size_t ck_tokenizer_hash_table_count(CKTokenizerHashTable * table)
Get the number of entries.
ck_tokenizer_hash_table_create
Forward
CKTokenizerHashTable * ck_tokenizer_hash_table_create(size_t bucket_count)
Create a hash table.
ck_tokenizer_hash_table_delete
Forward
int ck_tokenizer_hash_table_delete(CKTokenizerHashTable * table, const char * key, bool free_value)
Delete a key.
ck_tokenizer_hash_table_free
Forward
void ck_tokenizer_hash_table_free(CKTokenizerHashTable * table, bool free_values)
Free a hash table.
ck_tokenizer_hash_table_insert
Forward
int ck_tokenizer_hash_table_insert(CKTokenizerHashTable * table, const char * key, void * value)
Insert a key-value pair.
ck_tokenizer_hash_table_iterate
Forward
int ck_tokenizer_hash_table_iterate(CKTokenizerHashTable * table, CKTokenizerHashCallback callback, void * user_data)
ck_tokenizer_hash_table_keys
Forward
size_t ck_tokenizer_hash_table_keys(CKTokenizerHashTable * table, const char ** out_keys, size_t max_keys)
Get all keys as an array.
ck_tokenizer_hash_table_lookup
Forward
void * ck_tokenizer_hash_table_lookup(CKTokenizerHashTable * table, const char * key)
Look up a key.
ck_tokenizer_hash_table_lookup_avx
Forward
void * ck_tokenizer_hash_table_lookup_avx(CKTokenizerHashTable * table, const char * key)
ck_tokenizer_id_to_token
Forward
const char * ck_tokenizer_id_to_token(const CKTokenizer * tok, int32_t id)
ck_tokenizer_init
Forward
int ck_tokenizer_init(CKTokenizer * tok)
ck_tokenizer_load
Forward
int ck_tokenizer_load(CKTokenizer * tok, const char * path)
ck_tokenizer_load_binary
Forward
int ck_tokenizer_load_binary(CKTokenizer * tok, int vocab_size, const int32_t * offsets, const char * strings, int num_merges, const int32_t * merges)
Load vocabulary from memory-mapped binary data.
ck_tokenizer_load_binary_with_scores
Forward
int ck_tokenizer_load_binary_with_scores(CKTokenizer * tok, int vocab_size, const int32_t * offsets, const char * strings, const float * scores, const uint8_t * types, int num_merges, const int32_t * merges)
Load vocabulary from memory-mapped binary data with scores and types.
ck_tokenizer_load_gguf
Forward
int ck_tokenizer_load_gguf(CKTokenizer * tok, const char * path)
Load vocabulary from GGUF file.
ck_tokenizer_load_json
Forward
int ck_tokenizer_load_json(CKTokenizer * tok, const char * path)
Load vocabulary from JSON file (HuggingFace format).
ck_tokenizer_load_merges
Forward
int ck_tokenizer_load_merges(CKTokenizer * tok, const char * path)
Load BPE merges from text file.
ck_tokenizer_load_text
Forward
int ck_tokenizer_load_text(CKTokenizer * tok, const char * path)
Load vocabulary from text file (one token per line).
ck_tokenizer_lookup
Forward
int32_t ck_tokenizer_lookup(const CKTokenizer * tok, const char * token, int len)
ck_tokenizer_lookup_exact
Forward
int32_t ck_tokenizer_lookup_exact(const CKTokenizer * tok, const char * token)
ck_tokenizer_lookup_exact_n
Forward
int32_t ck_tokenizer_lookup_exact_n(const CKTokenizer * tok, const char * text, int text_len)
ck_tokenizer_lookup_merge
Forward
int ck_tokenizer_lookup_merge(const CKTokenizer * tok, int32_t left, int32_t right)
ck_tokenizer_mempool_alloc
Forward
void * ck_tokenizer_mempool_alloc(CKTokenizerMemPool * pool, size_t size)
Allocate from pool.
ck_tokenizer_mempool_alloc_aligned
Forward
void * ck_tokenizer_mempool_alloc_aligned(CKTokenizerMemPool * pool, size_t size, size_t align)
Allocate aligned memory from pool.
ck_tokenizer_mempool_alloc_count
Forward
size_t ck_tokenizer_mempool_alloc_count(CKTokenizerMemPool * pool)
Get allocation count.
ck_tokenizer_mempool_available
Forward
size_t ck_tokenizer_mempool_available(CKTokenizerMemPool * pool)
Get available bytes in pool.
ck_tokenizer_mempool_free
Forward
void ck_tokenizer_mempool_free(CKTokenizerMemPool * pool)
Free a memory pool.
ck_tokenizer_mempool_init
Forward
int ck_tokenizer_mempool_init(CKTokenizerMemPool * pool, size_t size)
Initialize a memory pool.
ck_tokenizer_mempool_reset
Forward
void ck_tokenizer_mempool_reset(CKTokenizerMemPool * pool)
Reset pool (mark all memory as free).
ck_tokenizer_mempool_strdup
Forward
char * ck_tokenizer_mempool_strdup(CKTokenizerMemPool * pool, const char * str)
Allocate and copy string (strdup equivalent).
ck_tokenizer_mempool_strndup
Forward
char * ck_tokenizer_mempool_strndup(CKTokenizerMemPool * pool, const char * str, int len)
Allocate and copy string with length.
ck_tokenizer_mempool_used
Forward
size_t ck_tokenizer_mempool_used(CKTokenizerMemPool * pool)
Get used bytes in pool.
ck_tokenizer_reset
Forward
void ck_tokenizer_reset(CKTokenizer * tok)
ck_tokenizer_set_add_bos_eos
Forward
void ck_tokenizer_set_add_bos_eos(CKTokenizer * tok, bool add_bos, bool add_eos)
ck_tokenizer_set_add_space_prefix
Forward
void ck_tokenizer_set_add_space_prefix(CKTokenizer * tok, bool add_space_prefix)
ck_tokenizer_set_space_prefix_style
Forward
void ck_tokenizer_set_space_prefix_style(CKTokenizer * tok, CKSpacePrefixStyle style)
ck_tokenizer_set_special_ids
Forward
void ck_tokenizer_set_special_ids(CKTokenizer * tok, int32_t unk, int32_t bos, int32_t eos, int32_t pad, int32_t mask)
ck_tokenizer_set_spm_mode
Forward
void ck_tokenizer_set_spm_mode(CKTokenizer * tok, CKSpmMode spm_mode)
ck_tokenizer_set_use_trie
Forward
void ck_tokenizer_set_use_trie(CKTokenizer * tok, bool use_trie)
ck_tokenizer_utf8_normalize_nfc
Forward
size_t ck_tokenizer_utf8_normalize_nfc(const char * src, size_t src_len, char * dst, size_t dst_size)
ck_tokenizer_vocab_size
Forward
int ck_tokenizer_vocab_size(const CKTokenizer * tok)
ck_trie_clear
Forward
void ck_trie_clear(CKTrie * trie)
ck_trie_create
Forward
CKTrie * ck_trie_create(size_t max_nodes)
ck_trie_find_longest
Forward
int32_t ck_trie_find_longest(const CKTrie * trie, const char * text, size_t text_len, size_t start_pos, size_t * match_len)
ck_trie_free
Forward
void ck_trie_free(CKTrie * trie)
ck_trie_has_prefix
Forward
bool ck_trie_has_prefix(const CKTrie * trie, const char * text, size_t text_len, size_t pos)
ck_trie_insert
Forward
int ck_trie_insert(CKTrie * trie, const char * token, int32_t token_id, bool is_special, int32_t priority)
ck_trie_node_count
Forward
size_t ck_trie_node_count(const CKTrie * trie)
ck_true_bpe_add_merge
Forward
int ck_true_bpe_add_merge(CKTrueBPE * bpe, int32_t left_id, int32_t right_id, int32_t merged_id, int32_t priority)
ck_true_bpe_add_merge_by_tokens
Forward
int ck_true_bpe_add_merge_by_tokens(CKTrueBPE * bpe, const char * left, const char * right, int32_t priority)
ck_true_bpe_add_special_token
Forward
int ck_true_bpe_add_special_token(CKTrueBPE * bpe, const char * token, int32_t id)
ck_true_bpe_add_token
Forward
int ck_true_bpe_add_token(CKTrueBPE * bpe, const char * token, int32_t id, float score)
ck_true_bpe_create
Forward
CKTrueBPE * ck_true_bpe_create(void)
ck_true_bpe_decode
Forward
int ck_true_bpe_decode(const CKTrueBPE * bpe, const int32_t * ids, int num_ids, char * text, int max_len)
ck_true_bpe_detect_space_style
Forward
CKSpacePrefixStyle ck_true_bpe_detect_space_style(CKTrueBPE * bpe)
ck_true_bpe_encode
Forward
int ck_true_bpe_encode(CKTrueBPE * bpe, const char * text, int text_len, int32_t * ids, int max_ids)
ck_true_bpe_free
Forward
void ck_true_bpe_free(CKTrueBPE * bpe)
ck_true_bpe_id_to_token
Forward
const char * ck_true_bpe_id_to_token(const CKTrueBPE * bpe, int32_t id)
ck_true_bpe_load_binary
Forward
int ck_true_bpe_load_binary(CKTrueBPE * bpe, int vocab_size, const int32_t * offsets, const char * strings, int num_merges, const int32_t * merges)
ck_true_bpe_lookup
Forward
int32_t ck_true_bpe_lookup(const CKTrueBPE * bpe, const char * token)
ck_true_bpe_num_merges
Forward
int32_t ck_true_bpe_num_merges(const CKTrueBPE * bpe)
ck_true_bpe_set_config
Forward
void ck_true_bpe_set_config(CKTrueBPE * bpe, const CKBPEConfig * config)
ck_true_bpe_set_special_ids
Forward
void ck_true_bpe_set_special_ids(CKTrueBPE * bpe, int32_t unk, int32_t bos, int32_t eos, int32_t pad)
ck_true_bpe_vocab_size
Forward
size_t ck_true_bpe_vocab_size(const CKTrueBPE * bpe)
ck_utf8_byte_to_offset
Forward
size_t ck_utf8_byte_to_offset(const char * str, size_t len, size_t byte_offset)
Get the character index from byte offset.
ck_utf8_char_length
Forward
int ck_utf8_char_length(unsigned char c)
Get the length of a UTF-8 character from its first byte.
ck_utf8_count_chars
Forward
size_t ck_utf8_count_chars(const char * str, size_t len)
Count UTF-8 characters in a string.
ck_utf8_decode_2
Forward
uint32_t ck_utf8_decode_2(const char * s)
Get 2-byte UTF-8 sequence value.
ck_utf8_decode_3
Forward
uint32_t ck_utf8_decode_3(const char * s)
Get 3-byte UTF-8 sequence value.
ck_utf8_decode_4
Forward
uint32_t ck_utf8_decode_4(const char * s)
Get 4-byte UTF-8 sequence value.
ck_utf8_first_byte
Forward
unsigned char ck_utf8_first_byte(const char * s)
Get the first byte of a UTF-8 character.
ck_utf8_from_cp
Forward
int ck_utf8_from_cp(uint32_t cp, char * out)
Write a Unicode code point as UTF-8.
ck_utf8_is_continuation
Forward
int ck_utf8_is_continuation(unsigned char c)
Check whether a byte is a UTF-8 continuation byte (returns nonzero if so).
ck_utf8_is_valid
Forward
bool ck_utf8_is_valid(const char * str, size_t len)
Check if a byte sequence is valid UTF-8.
ck_utf8_is_whitespace
Forward
bool ck_utf8_is_whitespace(uint32_t cp)
Check if character is whitespace (Unicode White_Space property).
ck_utf8_next_char
Forward
int32_t ck_utf8_next_char(const char ** str, int * out_len)
Get next UTF-8 character, return its code point.
ck_utf8_normalize_nfc
Forward
size_t ck_utf8_normalize_nfc(const char * src, size_t src_len, char * dst, size_t dst_size)
Normalize UTF-8 string (Unicode normalization form NFC).
ck_utf8_offset_to_byte
Forward
size_t ck_utf8_offset_to_byte(const char * str, size_t len, size_t n)
Get the byte offset of the N-th character.
ck_utf8_validate
Forward
size_t ck_utf8_validate(const char * str, size_t len)
Validate a UTF-8 string.
ck_weight_dtype_expr
Forward
const char * ck_weight_dtype_expr(const CKBufferSpec * spec)
ckernel_backend_native
CKMathBackend ckernel_backend_native(void)
Obtain the built-in native backend (single-node CPU, C + intrinsics).
clamp_int8
Forward
int8_t clamp_int8(float value)
compute_align
Forward
CKV2AlignInfo compute_align(const CKModelConfig * cfg)
compute_rms_scale
Forward
float compute_rms_scale(const float * x, int n, float eps)
compute_rms_scale_internal
Forward
float compute_rms_scale_internal(const float * x, int n, float eps)
convert_bf16_tensor_to_buf
Forward
void convert_bf16_tensor_to_buf(const uint16_t * src, float * dst, size_t count)
convert_f16_to_f32
Forward
void convert_f16_to_f32(float * dst, const uint16_t * src, size_t count)
Convert FP16 tensor to FP32.
convert_f32_to_f16
Forward
void convert_f32_to_f16(uint16_t * dst, const float * src, size_t count)
Convert FP32 tensor to FP16.
convert_float_to_int4
Forward
void convert_float_to_int4(const float * src, uint8_t * dst, size_t count)
convert_float_to_int8
Forward
void convert_float_to_int8(const float * src, int8_t * dst, size_t count)
convert_int4_to_float
Forward
void convert_int4_to_float(const uint8_t * src, float * dst, size_t count)
convert_int8_to_float
Forward
void convert_int8_to_float(const int8_t * src, float * dst, size_t count)
count_set_bits
Forward
int count_set_bits(const char * hex_mask)
cpu_features_init
Forward
void cpu_features_init(void)
create_entry
Forward
CKTokenizerHashEntry * create_entry(const char * key, const void * value, size_t value_size)
create_node
Forward
CKTrieNode * create_node(void)
decode_bpe_token
Forward
int decode_bpe_token(const char * token, char * out, int max)
Decode GPT-2 byte-level BPE representation back to actual bytes.
decode_int4
Forward
int8_t decode_int4(uint8_t packed, int index)
decode_layer_parallel
Forward
void decode_layer_parallel(float * hidden, const void * ln1_weight, const void * ln2_weight, const void * WQ, const void * WK, const void * WV, const void * WO, const void * W_gate, const void * W_up, const void * W_down, float * k_cache, float * v_cache, int token_index, float * scratch, int embed_dim, int intermediate, int H, int H_kv, int head_dim, int max_seq, float eps, int num_threads)
Process one transformer layer in parallel.
dequant_q4_0_block
Forward
void dequant_q4_0_block(const block_q4_0 * block, float * output)
Dequantize a single Q4_0 block to FP32.
dequant_q4_0_row
Forward
void dequant_q4_0_row(const void * src, float * dst, size_t n_elements)
Dequantize Q4_0 row (multiple blocks)
dequant_q4_1_block
Forward
void dequant_q4_1_block(const block_q4_1 * block, float * output)
Dequantize a single Q4_1 block to FP32.
dequant_q4_1_row
Forward
void dequant_q4_1_row(const void * src, float * dst, size_t n_elements)
Dequantize Q4_1 row (multiple blocks)
dequant_q4_k_block
Forward
void dequant_q4_k_block(const block_q4_K * block, float * output)
Dequantize a single Q4_K block to FP32.
dequant_q4_k_row
Forward
void dequant_q4_k_row(const void * src, float * dst, size_t n_elements)
Dequantize Q4_K row (multiple blocks)
dequant_q5_0_block
Forward
void dequant_q5_0_block(const block_q5_0 * block, float * output)
Dequantize a single Q5_0 block to FP32.
dequant_q5_0_row
Forward
void dequant_q5_0_row(const void * src, float * dst, size_t n_elements)
Dequantize Q5_0 row (multiple blocks)
dequant_q5_1_block
Forward
void dequant_q5_1_block(const block_q5_1 * block, float * output)
Dequantize a single Q5_1 block to FP32.
dequant_q5_1_row
Forward
void dequant_q5_1_row(const void * src, float * dst, size_t n_elements)
Dequantize Q5_1 row (multiple blocks)
dequant_q6_k_block
Forward
void dequant_q6_k_block(const block_q6_K * block, float * output)
Dequantize a single Q6_K block to FP32.
dequant_q6_k_row
Forward
void dequant_q6_k_row(const void * src, float * dst, size_t n_elements)
Dequantize Q6_K row (multiple blocks)
dequant_q8_0_block
Forward
void dequant_q8_0_block(const block_q8_0 * block, float * output)
Dequantize a single Q8_0 block to FP32.
dequant_q8_0_row
Forward
void dequant_q8_0_row(const void * src, float * dst, size_t n_elements)
Dequantize Q8_0 row (multiple blocks)
dequant_row
Forward
void dequant_row(CKDataType dtype, const void * src, float * dst, size_t n_elements)
Dequantize a row of quantized data to FP32.
detach_allocation
Forward
ck_huge_alloc_entry_t * detach_allocation(void * ptr)
detect_chat_template
Forward
ChatTemplateType detect_chat_template(const char * model_name)
detect_physical_cores
Forward
int detect_physical_cores(void)
dot_f16
Forward
float dot_f16(const uint16_t * w_f16, const float * x, int K)
dot_fp32_q5_0_block
Forward
float dot_fp32_q5_0_block(const float * x, const block_q5_0 * block)
Compute dot product of FP32 input with Q5_0 weight block, with online Q8 quantization.
dot_fp32_q8_0_block
Forward
float dot_fp32_q8_0_block(const float * x, const block_q8_0 * block)
Compute dot product of FP32 input with Q8_0 weight block, with online Q8 quantization.
dot_q4_0
Forward
float dot_q4_0(const void * w_q4_0, const float * x, int K)
dot_q4_1
Forward
float dot_q4_1(const void * w_q4_1, const float * x, int K)
dot_q4_k
Forward
float dot_q4_k(const void * w_q4k, const float * x, int K)
Compute dot product of Q4_K row with FP32 vector.
dot_q4_k_q8_k_ref
Forward
float dot_q4_k_q8_k_ref(const block_q4_K * w, const block_q8_K * x, int k)
dot_q5_0
Forward
float dot_q5_0(const void * w_q5_0, const float * x, int K)
dot_q5_0_q8_k_32_sse
Forward
float dot_q5_0_q8_k_32_sse(const block_q5_0 * bw, const block_q8_K * ba, int q8_offset)
dot_q5_1
Forward
float dot_q5_1(const void * w_q5_1, const float * x, int K)
dot_q6_k_q8_k_256_sse
Forward
float dot_q6_k_q8_k_256_sse(const block_q6_K * bw, const block_q8_K * ba)
SSE-optimized dot product for Q6_K × Q8_K. Q6_K block layout: ql — 128 bytes (low 4 bits of each quant); qh — 64 bytes (high 2 bits); scales — 16 bytes (int8 sub-block scales); d — fp16 super-block scale.
dot_q6_k_q8_k_ref
Forward
float dot_q6_k_q8_k_ref(const block_q6_K * w, const block_q8_K * x, int K)
Scalar dot product for Q6_K x Q8_K.
dot_q6_k_ref
Forward
float dot_q6_k_ref(const block_q6_K * w, const float * x, int K)
dot_q8_0
Forward
float dot_q8_0(const void * w_q8_0, const float * x, int K)
embedding_backward
Backward
void embedding_backward(const int32_t * token_ids, int token_count, const float * d_output, float * d_token_embeddings, float * d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
Backward pass / gradient computation
embedding_backward_bf16
Backward
void embedding_backward_bf16(const int32_t * token_ids, int token_count, const uint16_t * d_output, uint16_t * d_token_embeddings, uint16_t * d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
Backward pass / gradient computation
embedding_forward
Forward
void embedding_forward(const int32_t * token_ids, int token_count, int vocab_size, const float * token_embeddings, const float * pos_embeddings, float * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
Forward pass computation
embedding_forward_bf16
Forward
void embedding_forward_bf16(const int32_t * token_ids, int token_count, int vocab_size, const uint16_t * token_embeddings, const uint16_t * pos_embeddings, uint16_t * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
Forward pass computation
embedding_forward_q4_k
Forward
void embedding_forward_q4_k(const int32_t * token_ids, int token_count, int vocab_size, const void * token_embeddings, const float * pos_embeddings, float * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
Forward pass computation
embedding_forward_q6_k
Forward
void embedding_forward_q6_k(const int32_t * token_ids, int token_count, int vocab_size, const void * token_embeddings, const float * pos_embeddings, float * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
Forward pass computation
embedding_forward_q8_0
Forward
void embedding_forward_q8_0(const int32_t * token_ids, int token_count, int vocab_size, const void * token_embeddings, const float * pos_embeddings, float * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
Forward pass computation
emit_body_fields
Forward
void emit_body_fields(FILE * out, const CKIRV2Graph * graph, CKBufferRole role_filter, int activation_group)
emit_body_values
Forward
size_t emit_body_values(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan, CKBufferRole role_filter, int activation_group)
emit_bump_bytes_assignment
Forward
void emit_bump_bytes_assignment(FILE * out, const char * indent, const char * struct_prefix, const char * name, const CKDimToken * shape)
emit_bump_bytes_assignment_weight_dtype
Forward
void emit_bump_bytes_assignment_weight_dtype(FILE * out, const char * indent, const char * struct_prefix, const char * name, const CKDimToken * shape, const char * dtype_expr)
emit_dim_expr
Forward
void emit_dim_expr(FILE * out, CKDimKind dim)
emit_footer_fields
Forward
void emit_footer_fields(FILE * out, const CKIRV2Graph * graph, CKBufferRole role_filter, int activation_group)
emit_footer_values
Forward
void emit_footer_values(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan, CKBufferRole role_filter, int activation_group, size_t * offset)
emit_global_aliases_to_layer
Forward
void emit_global_aliases_to_layer(FILE * out)
emit_global_allocations
Forward
void emit_global_allocations(FILE * out)
emit_global_offset_fields
Forward
void emit_global_offset_fields(FILE * out)
emit_header_fields
Forward
void emit_header_fields(FILE * out, const CKIRV2Graph * graph, CKBufferRole role_filter, int activation_group)
emit_header_values
Forward
void emit_header_values(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan, CKBufferRole role_filter, int activation_group, size_t * offset)
emit_kernel_manifest
Forward
int emit_kernel_manifest(const CKIRGraph * forward, const char * runtime_path)
emit_layer_allocations
Forward
void emit_layer_allocations(FILE * out)
emit_layer_offsets_struct
Forward
void emit_layer_offsets_struct(FILE * out)
emit_library_api
Forward
void emit_library_api(FILE * out, const CKIRGraph * forward)
emit_model_struct
Forward
void emit_model_struct(FILE * out)
emit_offset_field
Forward
void emit_offset_field(FILE * out, const char * name)
emit_plan_sources
Forward
int emit_plan_sources(FILE * f, const CKPlanStep * plan, size_t plan_count, const CKIRGraph * cfg, const char ** seen, size_t * seen_count, size_t seen_cap)
emit_runtime_preamble
Forward
int emit_runtime_preamble(FILE * out)
emit_schedule_block
Forward
void emit_schedule_block(FILE * out, const CKIRV2Graph * graph, const char * func_name, const char * label, const char * runtime_sym)
emit_sgd_update
Forward
void emit_sgd_update(FILE * out)
emit_shape_expr
Forward
void emit_shape_expr(FILE * out, const CKDimToken * shape)
emit_span_field
Forward
void emit_span_field(FILE * out, const char * label)
emit_span_value
Forward
void emit_span_value(FILE * out, const char * label, size_t offset, size_t size, int comma)
emit_training_conditional_assignment
Forward
void emit_training_conditional_assignment(FILE * out, const char * indent, const char * struct_prefix, const char * name, const CKDimToken * shape)
emit_unique_source
Forward
int emit_unique_source(FILE * f, const char * path, const char ** seen, size_t * seen_count, size_t seen_cap)
emit_zero_grad
Forward
void emit_zero_grad(FILE * out)
encode_chunk
Forward
int encode_chunk(CKTrueBPE * bpe, const char * chunk, int chunk_len, int32_t * ids, int max_ids, CKBPETokenList * list)
encode_int4_nibble
Forward
uint8_t encode_int4_nibble(int8_t value)
encode_text_segment
Forward
int encode_text_segment(CKTrueBPE * bpe, const char * text, int text_len, int32_t * ids, int max_ids)
engine_thread_func
Forward
void * engine_thread_func(void * arg)
eos_is_potential_prefix
Forward
bool eos_is_potential_prefix(const char * token)
Check if token might be start of EOS pattern.
eos_pattern_init
Forward
void eos_pattern_init(ChatTemplateType tmpl)
eos_pattern_process
Forward
bool eos_pattern_process(const char * token_text, char * out_buf, size_t * out_len, void (*output_fn)(char *, size_t *, const char *), ChatTemplateType tmpl)
Process a token for EOS pattern detection.
eos_pattern_reset
Forward
void eos_pattern_reset(void)
find_best_merge
Forward
int find_best_merge(const CKTrueBPE * bpe, const CKBPETokenList * list, size_t * best_pos, const CKBPEMerge ** best_merge)
find_buffer_by_name
Forward
int find_buffer_by_name(const CKIRV2Graph * graph, const char * name)
find_longest_match
Forward
int32_t find_longest_match(const CKTokenizer * tok, const char * text, size_t text_len, size_t pos, size_t * match_len)
find_longest_match_hash
Forward
int32_t find_longest_match_hash(const CKTokenizer * tok, const char * text, size_t text_len, size_t pos, size_t * match_len)
find_longest_match_trie
Forward
int32_t find_longest_match_trie(const CKTokenizer * tok, const char * text, size_t text_len, size_t pos, size_t * match_len)
find_model_in_cache
bool find_model_in_cache(const char * model_name, char * lib_out, char * weights_out, size_t out_size)
find_object_range
Forward
int find_object_range(const char * json, const char * key, const char ** out_start, size_t * out_len)
flatten_head_major
Forward
void flatten_head_major(const float * attn_out, float * dst, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
float_tensor_to_bf16
Forward
void float_tensor_to_bf16(const float * src, uint16_t * dst, size_t count)
float_to_bf16
Forward
uint16_t float_to_bf16(float f)
format_bandwidth
Forward
const char * format_bandwidth(float bw_gbs, char * buf, size_t buf_size)
format_size
Forward
const char * format_size(uint64_t size_mb, char * buf, size_t buf_size)
free_entry
Forward
void free_entry(CKTokenizerHashEntry * entry, bool free_value)
fused_kernels_compute_kv_tile
Forward
int fused_kernels_compute_kv_tile(int l1_size, int head_dim, int bytes_per_elem)
Compute optimal KV tile size for flash attention.
fused_kernels_report_stats
Forward
void fused_kernels_report_stats(int hidden, int num_layers, int seq_len)
Report memory savings from mega-fusion.
fused_kernels_validate_constraints
Forward
int fused_kernels_validate_constraints(int l1_size, int head_dim, int kv_tile_size, int bytes_per_elem)
Validate cache constraints for fusion.
fused_output_projection_residual
Forward
void fused_output_projection_residual(float * output, const float * o_all, const float * W_o, const float * b_o, const float * residual, int hidden, int num_heads, int head_dim)
Fused output projection with residual add.
geglu_backward_fp32
Backward
void geglu_backward_fp32(const float * x, const float * d_out, float * d_x, int n_tokens, int dim)
GeGLU backward pass (fp32). Test: test_geglu.py::TestGeGLU::test_geglu_backward_fp32
geglu_forward_bf16
Forward
void geglu_forward_bf16(const uint16_t * x, uint16_t * out, int tokens, int dim, float * scratch)
GeGLU forward pass (bf16). Test: test_geglu.py::TestGeGLU::test_geglu_forward_bf16
geglu_forward_fp32
Forward
void geglu_forward_fp32(const float * x, float * out, int tokens, int dim)
GeGLU forward pass (fp32). Test: test_geglu.py::TestGeGLU::test_geglu_forward_fp32
gemv_f16
Forward
void gemv_f16(float * y, const uint16_t * W, const float * x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
gemv_f16_backward
Backward
void gemv_f16_backward(float * dX, const uint16_t * W, const float * dY, int M, int K)
Auto-dispatch backward.
gemv_f16_backward_ref
Backward
void gemv_f16_backward_ref(float * dX, const uint16_t * W, const float * dY, int M, int K)
Backward pass: compute input gradient (scalar reference)
gemv_f16_ref
Forward
void gemv_f16_ref(float * y, const uint16_t * W, const float * x, int M, int K)
Matrix-vector multiply with FP16 weights (scalar reference)
gemv_fused_q5_0_bias
Forward
void gemv_fused_q5_0_bias(float * y, const void * W, const float * x, const float * bias, int M, int K)
gemv_fused_q5_0_bias_dispatch
Forward
void gemv_fused_q5_0_bias_dispatch(float * y, const void * W, const float * x, const float * bias, int M, int K)
gemv_fused_q5_0_bias_parallel_omp
Forward
void gemv_fused_q5_0_bias_parallel_omp(float * y, const void * W, const float * x, const float * bias, int M, int K)
gemv_fused_q8_0_bias
Forward
void gemv_fused_q8_0_bias(float * y, const void * W, const float * x, const float * bias, int M, int K)
gemv_fused_q8_0_bias_dispatch
Forward
void gemv_fused_q8_0_bias_dispatch(float * y, const void * W, const float * x, const float * bias, int M, int K)
gemv_nt_q5_0_head_major_output
Forward
void gemv_nt_q5_0_head_major_output(float * output, const float * attn_out, const void * wo, const float * bias, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection reading head-major attention output (Q5_0 weights)
gemv_q4_0
Forward
void gemv_q4_0(float * y, const void * W, const float * x, int M, int K)
Auto-dispatch GEMV.
gemv_q4_0_backward
Backward
void gemv_q4_0_backward(float * dX, const void * W, const float * dY, int M, int K)
Auto-dispatch backward.
gemv_q4_0_backward_ref
Backward
void gemv_q4_0_backward_ref(float * dX, const void * W, const float * dY, int M, int K)
Backward pass: compute input gradient.
gemv_q4_0_ref
Forward
void gemv_q4_0_ref(float * y, const void * W, const float * x, int M, int K)
Matrix-vector multiply with Q4_0 weights (scalar reference)
gemv_q4_1
Forward
void gemv_q4_1(float * y, const void * W, const float * x, int M, int K)
Auto-dispatch GEMV.
gemv_q4_1_backward
Backward
void gemv_q4_1_backward(float * dX, const void * W, const float * dY, int M, int K)
Auto-dispatch backward.
gemv_q4_1_backward_ref
Backward
void gemv_q4_1_backward_ref(float * dX, const void * W, const float * dY, int M, int K)
Backward pass: compute input gradient.
gemv_q4_1_ref
Forward
void gemv_q4_1_ref(float * y, const void * W, const float * x, int M, int K)
Matrix-vector multiply with Q4_1 weights (scalar reference)
gemv_q4_k
Forward
void gemv_q4_k(float * y, const void * W, const float * x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
gemv_q4_k_backward
Backward
void gemv_q4_k_backward(float * dX, const void * W, const float * dY, int M, int K)
Auto-dispatch backward.
gemv_q4_k_backward_ref
Backward
void gemv_q4_k_backward_ref(float * dX, const void * W, const float * dY, int M, int K)
Backward pass: compute input gradient (scalar reference)
gemv_q4_k_q8_k
Forward
void gemv_q4_k_q8_k(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q4_k_q8_k_amx
Forward
void gemv_q4_k_q8_k_amx(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q4_k_q8_k_avx
Forward
void gemv_q4_k_q8_k_avx(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q4_k_q8_k_avx2
Forward
void gemv_q4_k_q8_k_avx2(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q4_k_q8_k_parallel
Forward
void gemv_q4_k_q8_k_parallel(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)
gemv_q4_k_q8_k_parallel_simd
Forward
void gemv_q4_k_q8_k_parallel_simd(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)
gemv_q4_k_q8_k_ref
Forward
void gemv_q4_k_q8_k_ref(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q4_k_q8_k_sse
Forward
void gemv_q4_k_q8_k_sse(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q4_k_q8_k_vnni
Forward
void gemv_q4_k_q8_k_vnni(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q4_k_ref
Forward
void gemv_q4_k_ref(float * y, const void * W, const float * x, int M, int K)
Matrix-vector multiply with Q4_K weights (scalar reference)
gemv_q5_0
Forward
void gemv_q5_0(float * y, const void * W, const float * x, int M, int K)
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
gemv_q5_0_backward
Backward
void gemv_q5_0_backward(float * dX, const void * W, const float * dY, int M, int K)
Auto-dispatch backward.
gemv_q5_0_backward_ref
Backward
void gemv_q5_0_backward_ref(float * dX, const void * W, const float * dY, int M, int K)
Backward pass: compute input gradient.
gemv_q5_0_from_fp32
Forward
void gemv_q5_0_from_fp32(float * out, const void * W_q5_0, const float * x_fp32, const float * bias, int M, int K, block_q8_0 * x_q8_scratch)
gemv_q5_0_parallel
Forward
void gemv_q5_0_parallel(float * y, const void * W, const float * x, int M, int K, int ith, int nth)
Parallel reference GEMV for Q5_0 × FP32.
gemv_q5_0_parallel_simd
Forward
void gemv_q5_0_parallel_simd(float * y, const void * W, const float * x, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q5_0 × FP32 with prefetching.
gemv_q5_0_q8_0
Forward
void gemv_q5_0_q8_0(float * y, const void * W, const void * x_q8, int M, int K)
Matrix-vector multiply with Q5_0 weights and Q8_0 input.
gemv_q5_0_q8_0_parallel_omp
Forward
void gemv_q5_0_q8_0_parallel_omp(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q5_0_q8_0_parallel_simd
Forward
void gemv_q5_0_q8_0_parallel_simd(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q5_0 × Q8_0 with prefetching.
gemv_q5_0_ref
Forward
void gemv_q5_0_ref(float * y, const void * W, const float * x, int M, int K)
Matrix-vector multiply with Q5_0 weights (scalar reference)
gemv_q5_1
Forward
void gemv_q5_1(float * y, const void * W, const float * x, int M, int K)
Auto-dispatch GEMV.
gemv_q5_1_backward
Backward
void gemv_q5_1_backward(float * dX, const void * W, const float * dY, int M, int K)
Auto-dispatch backward.
gemv_q5_1_backward_ref
Backward
void gemv_q5_1_backward_ref(float * dX, const void * W, const float * dY, int M, int K)
Backward pass: compute input gradient.
gemv_q5_1_ref
Forward
void gemv_q5_1_ref(float * y, const void * W, const float * x, int M, int K)
Matrix-vector multiply with Q5_1 weights (scalar reference)
gemv_q5_k
Forward
void gemv_q5_k(float * y, const void * W, const float * x, int M, int K)
gemv_q5_k_ref
Forward
void gemv_q5_k_ref(float * y, const void * W, const float * x, int M, int K)
gemv_q6_k
Forward
void gemv_q6_k(float * y, const void * W, const float * x, int M, int K)
gemv_q6_k_q8_k
Forward
void gemv_q6_k_q8_k(float * y, const void * W, const void * x_q8, int M, int K)
GEMV: y = W @ x where W is Q6_K and x is Q8_K.
gemv_q6_k_q8_k_avx
Forward
void gemv_q6_k_q8_k_avx(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q6_k_q8_k_avx2
Forward
void gemv_q6_k_q8_k_avx2(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q6_k_q8_k_avx512
Forward
void gemv_q6_k_q8_k_avx512(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q6_k_q8_k_avx512_vbmi
Forward
void gemv_q6_k_q8_k_avx512_vbmi(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q6_k_q8_k_parallel
Forward
void gemv_q6_k_q8_k_parallel(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)
Parallel reference GEMV for Q6_K × Q8_K.
gemv_q6_k_q8_k_parallel_simd
Forward
void gemv_q6_k_q8_k_parallel_simd(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q6_K × Q8_K.
gemv_q6_k_q8_k_ref
Forward
void gemv_q6_k_q8_k_ref(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q6_k_q8_k_sse
Forward
void gemv_q6_k_q8_k_sse(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q8_0
Forward
void gemv_q8_0(float * y, const void * W, const float * x, int M, int K)
Auto-dispatch GEMV for Q8_0 weights based on CPU features.
gemv_q8_0_backward
Backward
void gemv_q8_0_backward(float * dX, const void * W, const float * dY, int M, int K)
Auto-dispatch backward.
gemv_q8_0_backward_ref
Backward
void gemv_q8_0_backward_ref(float * dX, const void * W, const float * dY, int M, int K)
Backward pass: compute input gradient (scalar reference)
gemv_q8_0_from_fp32
Forward
void gemv_q8_0_from_fp32(float * out, const void * W_q8_0, const float * x_fp32, const float * bias, int M, int K, block_q8_0 * x_q8_scratch)
gemv_q8_0_parallel_simd
Forward
void gemv_q8_0_parallel_simd(float * y, const void * W, const float * x, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q8_0 weights × FP32 input with prefetching.
gemv_q8_0_q8_0
Forward
void gemv_q8_0_q8_0(float * y, const void * W, const void * x_q8, int M, int K)
Matrix-vector multiply with Q8_0 weights and Q8_0 input.
gemv_q8_0_q8_0_parallel
Forward
void gemv_q8_0_q8_0_parallel(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)
Parallel reference GEMV for Q8_0 × Q8_0.
gemv_q8_0_q8_0_parallel_omp
Forward
void gemv_q8_0_q8_0_parallel_omp(float * y, const void * W, const void * x_q8, int M, int K)
gemv_q8_0_q8_0_parallel_simd
Forward
void gemv_q8_0_q8_0_parallel_simd(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)
Parallel SIMD GEMV for Q8_0 × Q8_0 with prefetching.
gemv_q8_0_ref
Forward
void gemv_q8_0_ref(float * y, const void * W, const float * x, int M, int K)
Matrix-vector multiply with Q8_0 weights (scalar reference)
get_cache_dir
const char * get_cache_dir(void)
get_cpu_info
Forward
const CPUInfo * get_cpu_info(void)
get_current_cpu
Forward
int get_current_cpu(void)
get_numa_node_for_cpu
Forward
int get_numa_node_for_cpu(int cpu)
get_optimal_decode_threads
Forward
int get_optimal_decode_threads(void)
get_q5_k_scale_min
Forward
void get_q5_k_scale_min(int j, const uint8_t * scales, uint8_t * scale, uint8_t * min)
get_time_ms
Forward
double get_time_ms(void)
global_pool_init
Forward
void global_pool_init(void)
gpt2_decode_byte
Forward
int gpt2_decode_byte(const unsigned char * s, int len)
gpt2_pretokenize
Forward
int gpt2_pretokenize(const char * text, int text_len, PretokChunk * chunks, int max_chunks)
gradient_accumulate_bf16
Forward
void gradient_accumulate_bf16(uint16_t * dst, const uint16_t * src, size_t numel)
Accumulate gradients: dst += src (bf16)
gradient_accumulate_f32
Forward
void gradient_accumulate_f32(float * dst, const float * src, size_t numel)
Accumulate gradients: dst += src (fp32)
gradient_clip_norm_bf16
Forward
float gradient_clip_norm_bf16(uint16_t * grad, size_t numel, float max_norm)
Clip gradient norm (bf16)
gradient_clip_norm_f32
Forward
float gradient_clip_norm_f32(float * grad, size_t numel, float max_norm)
Clip gradient norm (fp32)
gradient_scale_bf16
Forward
void gradient_scale_bf16(uint16_t * grad, size_t numel, float scale)
Scale gradients: grad *= scale (bf16)
gradient_scale_f32
Forward
void gradient_scale_f32(float * grad, size_t numel, float scale)
Scale gradients by a constant: grad *= scale (fp32)
handle_sigint
Forward
void handle_sigint(int sig)
has_cpu_flag
Forward
int has_cpu_flag(const char * flags, const char * flag)
hash_pair
Forward
uint32_t hash_pair(int32_t left, int32_t right)
hash_string
Forward
uint32_t hash_string(const char * s, int len)
hmax256_ps_fused
Forward
float hmax256_ps_fused(__m256 v)
hsum256_ps_fused
Forward
float hsum256_ps_fused(__m256 v)
hsum_epi32_sse
Forward
int32_t hsum_epi32_sse(__m128i v)
im2patch
Forward
void im2patch(const float * image, float * patches, int C, int H, int W, int P)
im2patch: Transforms an image into a sequence of flattened patches.
im2patch_bf16
Forward
void im2patch_bf16(const uint16_t * image, uint16_t * patches, int C, int H, int W, int P)
init_tokens_from_text
Forward
int init_tokens_from_text(CKTrueBPE * bpe, CKBPETokenList * list, const char * text, int text_len)
is_activation_role
Forward
int is_activation_role(CKBufferRole role)
is_bpe_digit
Forward
bool is_bpe_digit(const char * s, int len)
is_bpe_letter
Forward
bool is_bpe_letter(const char * s, int len)
is_bpe_newline
Forward
bool is_bpe_newline(const char * s, int len)
is_bpe_punct
Forward
bool is_bpe_punct(const char * s, int len)
is_digit
Forward
bool is_digit(unsigned char c)
is_eos_token
Forward
bool is_eos_token(const CLIOptions * opt, int token)
is_footer_global
Forward
int is_footer_global(const char * name)
is_gpt2_space
Forward
bool is_gpt2_space(const char * s, int len)
is_letter
Forward
bool is_letter(unsigned char c)
is_whitespace
Forward
bool is_whitespace(unsigned char c)
is_word_prefix_char
Forward
bool is_word_prefix_char(const char * s, int len)
json_match_char
Forward
int json_match_char(JSONParser * p, char c)
json_parse_int
Forward
int json_parse_int(JSONParser * p, int * out)
json_parse_string
Forward
int json_parse_string(JSONParser * p, char * buf, int max_len)
json_skip_value
Forward
void json_skip_value(JSONParser * p)
json_skip_whitespace
Forward
void json_skip_whitespace(JSONParser * p)
kv_cache_repack_head_major_inplace
void kv_cache_repack_head_major_inplace(float * buf, int num_heads, int tokens, int cache_capacity, int aligned_head_dim)
kv_cache_store
void kv_cache_store(float *__restrict kv_cache_k, float *__restrict kv_cache_v, const float *__restrict k, const float *__restrict v, int layer, int pos, int num_kv_heads, int head_dim, int max_seq_len)
kv_cache_write_head_major
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
layout_transformer_from_ir
Forward
void layout_transformer_from_ir(TransformerModel * m, const CKIRGraph * ir)
Compute a simple forward-only layout for TransformerModel based on: CKModelConfig (dims, heads, vocab, context) The IR graph structure (number of layers, op types)
list_available_models
Forward
void list_available_models(void)
load_eos_from_vocab_json
Forward
bool load_eos_from_vocab_json(const char * weights_path, CLIOptions * opt)
load_manifest
Forward
int load_manifest(const char * path, ManifestEntry ** entries, int * num_entries)
load_model_api
Forward
bool load_model_api(const char * lib_path, ModelAPI * api)
load_weights
Forward
int load_weights(QWEN2_DECODEModel * model, const char * bump_path, const char * manifest_path)
load_weights_from_bump
Forward
int load_weights_from_bump(void * model, const char * bump_path)
logits_copy_to_position
Forward
void logits_copy_to_position(const float *__restrict src, float *__restrict dst, int position, int vocab_size)
Copy logits to position-indexed location in output buffer.
main
Forward
int main(int argc, char ** argv)
map_forward_to_backward
Backward
CKOpType map_forward_to_backward(CKOpType op)
Backward pass / gradient computation
match_special_token
Forward
int match_special_token(const CKTrueBPE * bpe, const char * text, int text_len, int pos)
max_k_for_query
Forward
int max_k_for_query(int t_q, int T_q, int T_k)
mega_fuse_get_optimal_tiles
Forward
void mega_fuse_get_optimal_tiles(int * q_tile, int * kv_tile, int head_dim)
Get optimal tile sizes for current CPU.
mega_fuse_output_proj_residual
Forward
void mega_fuse_output_proj_residual(const float * attn_token, const float * wo, const float * bo, const float * residual, float * output, int embed_dim, int aligned_embed_dim, int num_heads, int head_dim, int aligned_head_dim)
mega_fuse_report_stats
Forward
void mega_fuse_report_stats(int hidden, int num_layers, int seq_len)
Report memory savings from mega-fusion.
merge_hash
Forward
size_t merge_hash(uint64_t key, size_t num_buckets)
merge_key
Forward
uint64_t merge_key(int32_t left_id, int32_t right_id)
merge_table_create
Forward
CKMergeTable * merge_table_create(size_t num_buckets)
merge_table_free
Forward
void merge_table_free(CKMergeTable * table)
merge_table_insert
Forward
int merge_table_insert(CKMergeTable * table, const CKBPEMerge * merge)
merge_table_lookup
Forward
const CKBPEMerge * merge_table_lookup(const CKMergeTable * table, int32_t left_id, int32_t right_id)
model_align_elems
Forward
int model_align_elems(int elems, int elem_bytes, int align_bytes)
model_decode
Forward
void model_decode(MODELModel * model, const int * token, int token_index)
model_decode_token
Forward
void model_decode_token(MODELModel * model, const int * token, int token_index)
model_forward
Forward
void model_forward(MODELModel * model, const int * tokens, int num_tokens)
Forward pass computation
model_forward_prefill_impl
Forward
void model_forward_prefill_impl(MODELModel * model, const int * tokens, int num_tokens)
Forward pass computation
model_layer_0_decode
Forward
void model_layer_0_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_0_prefill
Forward
void model_layer_0_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_10_decode
Forward
void model_layer_10_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_10_prefill
Forward
void model_layer_10_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_11_decode
Forward
void model_layer_11_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_11_prefill
Forward
void model_layer_11_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_12_decode
Forward
void model_layer_12_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_12_prefill
Forward
void model_layer_12_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_13_decode
Forward
void model_layer_13_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_13_prefill
Forward
void model_layer_13_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_14_decode
Forward
void model_layer_14_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_14_prefill
Forward
void model_layer_14_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_15_decode
Forward
void model_layer_15_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_15_prefill
Forward
void model_layer_15_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_16_decode
Forward
void model_layer_16_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_16_prefill
Forward
void model_layer_16_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_17_decode
Forward
void model_layer_17_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_17_prefill
Forward
void model_layer_17_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_18_decode
Forward
void model_layer_18_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_18_prefill
Forward
void model_layer_18_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_19_decode
Forward
void model_layer_19_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_19_prefill
Forward
void model_layer_19_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_1_decode
Forward
void model_layer_1_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_1_prefill
Forward
void model_layer_1_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_20_decode
Forward
void model_layer_20_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_20_prefill
Forward
void model_layer_20_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_21_decode
Forward
void model_layer_21_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_21_prefill
Forward
void model_layer_21_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_22_decode
Forward
void model_layer_22_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_22_prefill
Forward
void model_layer_22_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_23_decode
Forward
void model_layer_23_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_23_prefill
Forward
void model_layer_23_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_2_decode
Forward
void model_layer_2_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_2_prefill
Forward
void model_layer_2_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_3_decode
Forward
void model_layer_3_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_3_prefill
Forward
void model_layer_3_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_4_decode
Forward
void model_layer_4_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_4_prefill
Forward
void model_layer_4_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_5_decode
Forward
void model_layer_5_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_5_prefill
Forward
void model_layer_5_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_6_decode
Forward
void model_layer_6_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_6_prefill
Forward
void model_layer_6_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_7_decode
Forward
void model_layer_7_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_7_prefill
Forward
void model_layer_7_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_8_decode
Forward
void model_layer_8_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_8_prefill
Forward
void model_layer_8_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_9_decode
Forward
void model_layer_9_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_layer_9_prefill
Forward
void model_layer_9_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
model_model_allocate
Forward
int model_model_allocate(MODELModel * model)
model_model_free
Forward
void model_model_free(MODELModel * model)
model_residual_add_token_major
Forward
void model_residual_add_token_major(const float * a, const float * b, float * out, int tokens, int aligned_embed_dim)
model_verify_canaries
Forward
int model_verify_canaries(MODELModel * model)
moe_accumulate_expert_f32
Forward
void moe_accumulate_expert_f32(float * output, const float * expert_output, float routing_weight, int hidden_dim)
Accumulate expert output: output += routing_weight * expert_output.
op_name
Forward
const char * op_name(CKOpType op)
out_proj_head_major_q5_0_q8_0
Forward
void out_proj_head_major_q5_0_q8_0(const uint8_t * attn_q8, const void * wo, const float * bias, float * output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
out_proj_head_major_q8_0_q8_0
Forward
void out_proj_head_major_q8_0_q8_0(const uint8_t * attn_q8, const void * wo, const float * bias, float * output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)
output_append
Forward
void output_append(char * buf, size_t * len, const char * text)
output_flush
Forward
void output_flush(char * buf, size_t * len)
output_token
Forward
void output_token(char * buf, size_t * len, const char * token)
pack_a_panel
Forward
void pack_a_panel(const float * A, int lda, float * Ap, int mc, int kc, int mr)
pack_b_panel
Forward
void pack_b_panel(const float * B, int ldb, float * Bp, int kc, int nc, int nr)
parse_args
Forward
bool parse_args(int argc, char ** argv, CLIOptions * opt)
parse_eos_ids
Forward
bool parse_eos_ids(const char * arg, CLIOptions * opt)
parse_float_field_any
Forward
int parse_float_field_any(const char * json, size_t len, const char *const * keys, float * out_value)
parse_float_field_in_range
Forward
int parse_float_field_in_range(const char * json, size_t len, const char * key, float * out_value)
parse_int_field
Forward
int parse_int_field(const char * json, const char * key, int * out_value)
parse_int_field_any
Forward
int parse_int_field_any(const char * json, size_t len, const char *const * keys, int * out_value)
parse_int_field_in_range
Forward
int parse_int_field_in_range(const char * json, size_t len, const char * key, int * out_value)
parse_manifest_entry
Forward
bool parse_manifest_entry(const char * json, const char * name, size_t * offset, size_t * size)
parse_manifest_int
Forward
int parse_manifest_int(const char * json, const char * key)
parse_op
Forward
CKOpType parse_op(const char * s)
parse_u64
Forward
unsigned long long parse_u64(const char * s)
patch2im
Backward
void patch2im(const float * d_patches, float * d_image, int C, int H, int W, int P)
patch2im: Accumulates gradients from patches back into the image. (Backward pass)
patch2im_bf16
Backward
void patch2im_bf16(const uint16_t * d_patches, uint16_t * d_image, int C, int H, int W, int P)
pcie_bandwidth_gbs
Forward
float pcie_bandwidth_gbs(int gen, int width)
plan_size
Forward
size_t plan_size(const CKMemPlan * plan, int idx)
pool_new_block
Forward
CKPoolBlock * pool_new_block(size_t capacity)
preprocess_bpe_spaces
Forward
int preprocess_bpe_spaces(const char * text, int text_len, char * out, int out_max, CKSpacePrefixStyle style)
preprocess_spm_llama_text
Forward
int preprocess_spm_llama_text(const char * text, int text_len, char * out, int out_max, bool add_space_prefix)
preprocess_spm_text
Forward
int preprocess_spm_text(const char * text, int text_len, char * out, int out_max, bool add_space_prefix)
preprocess_text
Forward
int preprocess_text(const CKTrueBPE * bpe, const char * text, int text_len, char * out, int out_max)
print_banner
Forward
void print_banner(void)
print_cpu_info
Forward
void print_cpu_info(void)
print_header
Forward
void print_header(const char * title)
print_help
Forward
void print_help(const char * prog)
print_ok
Forward
void print_ok(const char * msg)
print_progress
Forward
void print_progress(int token_id, float token_per_sec)
print_section
Forward
void print_section(const char * title)
print_tree_item
Forward
void print_tree_item(int level, int is_last, const char * fmt, ...)
print_usage
Forward
void print_usage(const char * prog)
print_version
Forward
void print_version(void)
print_warning
Forward
void print_warning(const char * msg)
process_repl_command
Forward
bool process_repl_command(const char * line, CLIOptions * opt, ModelAPI * api)
qk_norm_forward
Forward
void qk_norm_forward(float * q, float * k, const float * q_gamma, const float * k_gamma, int num_heads, int num_kv_heads, int num_tokens, int head_dim, float eps)
Per-head RMSNorm on Q and K.
qkv_index
Forward
size_t qkv_index(int h, int t, int d, int num_tokens, int aligned_head_dim)
qkv_projection_parallel
Forward
void qkv_projection_parallel(const void * ln1_q8, const void * WQ, const void * WK, const void * WV, float * q_out, float * k_out, float * v_out, int H, int H_kv, int head_dim, int embed_dim, int num_threads)
Parallel Q/K/V projection for single token decode.
qkv_q8_0_dtype_supported
Forward
int qkv_q8_0_dtype_supported(CKDataType dt)
qkv_q8_k_dtype_supported
Forward
int qkv_q8_k_dtype_supported(CKDataType dt)
quantize_attn_out_head_major_q8_0
Forward
void quantize_attn_out_head_major_q8_0(const float * attn_out, uint8_t * dst, int tokens, int num_heads, int aligned_head_dim)
quantize_batch_q8_0
Forward
void quantize_batch_q8_0(const float * x, void * vy, int num_rows, int k)
Batch quantize FP32 to Q8_0 format (row-major output)
quantize_batch_q8_k
Forward
void quantize_batch_q8_k(const float * x, void * vy, int num_rows, int k)
Batch quantize FP32 to Q8_K format (row-major output)
quantize_row_q8_0
Forward
void quantize_row_q8_0(const float * x, void * vy, int k)
Quantize FP32 to Q8_0 format (scalar reference)
quantize_row_q8_k
Forward
void quantize_row_q8_k(const float * x, void * vy, int k)
quantize_row_q8_k_ref
Forward
void quantize_row_q8_k_ref(const float * x, void * vy, int k)
quantize_row_q8_k_sse
Forward
void quantize_row_q8_k_sse(const float * x, void * vy, int k)
qwen2_0_5b_decode_align_elems
Forward
int qwen2_0_5b_decode_align_elems(int elems, int elem_bytes, int align_bytes)
qwen2_0_5b_decode_decode
Forward
void qwen2_0_5b_decode_decode(QWEN2_0_5B_DECODEModel * model, const int * token, int token_index)
qwen2_0_5b_decode_decode_token
Forward
void qwen2_0_5b_decode_decode_token(QWEN2_0_5B_DECODEModel * model, const int * token, int token_index)
qwen2_0_5b_decode_forward
Forward
void qwen2_0_5b_decode_forward(QWEN2_0_5B_DECODEModel * model, const int * tokens, int num_tokens)
Forward pass computation
qwen2_0_5b_decode_forward_prefill_impl
Forward
void qwen2_0_5b_decode_forward_prefill_impl(QWEN2_0_5B_DECODEModel * model, const int * tokens, int num_tokens)
Forward pass computation
qwen2_0_5b_decode_layer_0_decode
Forward
void qwen2_0_5b_decode_layer_0_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_0_prefill
Forward
void qwen2_0_5b_decode_layer_0_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_10_decode
Forward
void qwen2_0_5b_decode_layer_10_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_10_prefill
Forward
void qwen2_0_5b_decode_layer_10_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_11_decode
Forward
void qwen2_0_5b_decode_layer_11_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_11_prefill
Forward
void qwen2_0_5b_decode_layer_11_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_12_decode
Forward
void qwen2_0_5b_decode_layer_12_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_12_prefill
Forward
void qwen2_0_5b_decode_layer_12_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_13_decode
Forward
void qwen2_0_5b_decode_layer_13_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_13_prefill
Forward
void qwen2_0_5b_decode_layer_13_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_14_decode
Forward
void qwen2_0_5b_decode_layer_14_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_14_prefill
Forward
void qwen2_0_5b_decode_layer_14_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_15_decode
Forward
void qwen2_0_5b_decode_layer_15_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_15_prefill
Forward
void qwen2_0_5b_decode_layer_15_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_16_decode
Forward
void qwen2_0_5b_decode_layer_16_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_16_prefill
Forward
void qwen2_0_5b_decode_layer_16_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_17_decode
Forward
void qwen2_0_5b_decode_layer_17_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_17_prefill
Forward
void qwen2_0_5b_decode_layer_17_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_18_decode
Forward
void qwen2_0_5b_decode_layer_18_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_18_prefill
Forward
void qwen2_0_5b_decode_layer_18_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_19_decode
Forward
void qwen2_0_5b_decode_layer_19_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_19_prefill
Forward
void qwen2_0_5b_decode_layer_19_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_1_decode
Forward
void qwen2_0_5b_decode_layer_1_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_1_prefill
Forward
void qwen2_0_5b_decode_layer_1_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_20_decode
Forward
void qwen2_0_5b_decode_layer_20_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_20_prefill
Forward
void qwen2_0_5b_decode_layer_20_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_21_decode
Forward
void qwen2_0_5b_decode_layer_21_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_21_prefill
Forward
void qwen2_0_5b_decode_layer_21_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_22_decode
Forward
void qwen2_0_5b_decode_layer_22_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_22_prefill
Forward
void qwen2_0_5b_decode_layer_22_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_23_decode
Forward
void qwen2_0_5b_decode_layer_23_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_23_prefill
Forward
void qwen2_0_5b_decode_layer_23_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_2_decode
Forward
void qwen2_0_5b_decode_layer_2_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_2_prefill
Forward
void qwen2_0_5b_decode_layer_2_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_3_decode
Forward
void qwen2_0_5b_decode_layer_3_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_3_prefill
Forward
void qwen2_0_5b_decode_layer_3_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_4_decode
Forward
void qwen2_0_5b_decode_layer_4_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_4_prefill
Forward
void qwen2_0_5b_decode_layer_4_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_5_decode
Forward
void qwen2_0_5b_decode_layer_5_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_5_prefill
Forward
void qwen2_0_5b_decode_layer_5_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_6_decode
Forward
void qwen2_0_5b_decode_layer_6_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_6_prefill
Forward
void qwen2_0_5b_decode_layer_6_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_7_decode
Forward
void qwen2_0_5b_decode_layer_7_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_7_prefill
Forward
void qwen2_0_5b_decode_layer_7_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_8_decode
Forward
void qwen2_0_5b_decode_layer_8_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_8_prefill
Forward
void qwen2_0_5b_decode_layer_8_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_9_decode
Forward
void qwen2_0_5b_decode_layer_9_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_layer_9_prefill
Forward
void qwen2_0_5b_decode_layer_9_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)
qwen2_0_5b_decode_model_allocate
Forward
int qwen2_0_5b_decode_model_allocate(QWEN2_0_5B_DECODEModel * model)
qwen2_0_5b_decode_model_free
Forward
void qwen2_0_5b_decode_model_free(QWEN2_0_5B_DECODEModel * model)
qwen2_0_5b_decode_residual_add_token_major
Forward
void qwen2_0_5b_decode_residual_add_token_major(const float * a, const float * b, float * out, int tokens, int aligned_embed_dim)
qwen2_0_5b_decode_verify_canaries
Forward
int qwen2_0_5b_decode_verify_canaries(QWEN2_0_5B_DECODEModel * model)
read_file_int
Forward
int read_file_int(const char * path)
read_file_string
Forward
int read_file_string(const char * path, char * buf, size_t buf_size)
read_file_uint64
Forward
uint64_t read_file_uint64(const char * path)
read_floats
Forward
int read_floats(FILE * f, float * dst, size_t count)
read_manifest_entry
Forward
bool read_manifest_entry(const char * json_path, const char * entry_name, size_t * out_offset, size_t * out_size)
read_prompt_file
Forward
char * read_prompt_file(const char * path)
record_allocation
Forward
int record_allocation(void * ptr, size_t len, int was_mmap)
relu_backward
Backward
void relu_backward(const float * input, const float * d_output, float * d_input, size_t n)
Backward pass / gradient computation
relu_backward_bf16
Backward
void relu_backward_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, size_t n)
Backward pass / gradient computation
relu_forward
Forward
void relu_forward(const float * input, float * output, size_t n)
Forward pass computation
relu_forward_bf16
Forward
void relu_forward_bf16(const uint16_t * input, uint16_t * output, size_t n)
Forward pass computation
relu_forward_inplace
Forward
void relu_forward_inplace(float * data, size_t n)
Forward pass computation
relu_forward_inplace_bf16
Forward
void relu_forward_inplace_bf16(uint16_t * data, size_t n)
Forward pass computation
residual_add
Forward
void residual_add(float * residual, float * addend, int n)
residual_add_parallel
Forward
void residual_add_parallel(const float * a, const float * b, float * out, int n, int ith, int nth)
Parallel elementwise residual addition: computes out = a + b over n elements, with the work partitioned across nth workers (this call processes the slice belonging to worker ith).
resolve_dim
Forward
size_t resolve_dim(const CKModelConfig * cfg, const CKIRV2AlignInfo * align, CKDimKind kind, int tokens_override)
resolve_shape_elems
Forward
size_t resolve_shape_elems(const CKModelConfig * cfg, const CKIRV2AlignInfo * align, const CKDimToken * shape, int tokens_override)
resolve_symbol
Forward
bool resolve_symbol(void * handle, const char * name, void ** out_ptr, bool required)
run_benchmark
Forward
void run_benchmark(void * model, int num_tokens)
run_command
Forward
int run_command(const char * cmd, char * output, size_t output_size)
run_generation_test
Forward
void run_generation_test(void * model, int num_tokens)
run_inference
Forward
int run_inference(const char * bump_path, const char * manifest_path, const char * tokenizer_path, const char * prompt, int max_tokens, float temperature, int topk)
run_prompt
Forward
int run_prompt(ModelAPI * api, CKTrueBPE * tokenizer, CLIOptions * opt, const char * input)
sample_argmax
Forward
int sample_argmax(const float * logits, int vocab_size)
sample_token
Forward
int sample_token(float * logits, int vocab_size, float temp, int top_k)
sample_top_p
Forward
int sample_top_p(float * logits, int vocab_size, float temperature, float top_p)
sample_topk
Forward
int sample_topk(float * probs, int vocab_size, int topk)
scal_copy_f32
Forward
void scal_copy_f32(float * y, const float * x, float alpha, int n)
Scaled copy: y = alpha * x.
score_index
Forward
size_t score_index(int h, int i, int j, int aligned_context_window)
sgd_momentum_update_bf16
Forward
void sgd_momentum_update_bf16(const uint16_t * grad, uint16_t * weight, float * velocity, size_t numel, float lr, float momentum, float weight_decay)
SGD with momentum (bf16 weights/gradients)
sgd_momentum_update_f32
Forward
void sgd_momentum_update_f32(const float * grad, float * weight, float * velocity, size_t numel, float lr, float momentum, float weight_decay)
SGD with momentum optimizer update (fp32 version)
silu
Forward
void silu(float * x, int n)
silu_prefill
Forward
float silu_prefill(float x)
silu_scalar
Forward
float silu_scalar(float x)
simd_strcmp
Forward
int simd_strcmp(const char * s1, const char * s2)
simple_embedding
Forward
void simple_embedding(const int32_t * tokens, int num_tokens, const float * weight, float * output, int vocab_size, int embed_dim)
spm_build_byte_lookup
Forward
void spm_build_byte_lookup(CKTokenizer * tok, const char * strings, const int32_t * offsets, int vocab_size)
spm_count_unknown_run
Forward
int spm_count_unknown_run(const CKTokenizer * tok, const char * text, int text_len, size_t pos)
spm_encode_byte_fallback
Forward
int spm_encode_byte_fallback(const CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids)
spm_find_candidates_at_pos
Forward
int spm_find_candidates_at_pos(const CKTokenizer * tok, const char * text, int text_len, size_t pos, int32_t * candidates, int max_candidates)
spm_get_byte_token
Forward
int32_t spm_get_byte_token(const CKTokenizer * tok, unsigned char byte_val)
spm_is_byte_token
Forward
bool spm_is_byte_token(const CKTokenizer * tok, int32_t token_id)
spm_llama_resegment_node
Forward
int spm_llama_resegment_node(const CKTokenizer * tok, const SpmLlamaNode * nodes, int node_id, int32_t * ids, int max_ids, int out_idx)
spm_token_allowed_in_dp
Forward
bool spm_token_allowed_in_dp(const CKTokenizer * tok, int32_t token_id)
spm_token_is_byte_format
Forward
bool spm_token_is_byte_format(const char * token)
starts_with
Forward
int starts_with(const char * s, const char * prefix)
token_list_append
Forward
int token_list_append(CKBPETokenList * list, const char * str, size_t len, int32_t id)
token_list_clear
Forward
void token_list_clear(CKBPETokenList * list)
token_list_create
Forward
CKBPETokenList * token_list_create(size_t initial_capacity)
token_list_free
Forward
void token_list_free(CKBPETokenList * list)
token_list_merge_at
Forward
int token_list_merge_at(CKBPETokenList * list, size_t pos, const char * merged_str, size_t merged_len, int32_t merged_id)
tokenize
Forward
int32_t * tokenize(const char * text, int * num_tokens)
topk_batched_f32
Forward
void topk_batched_f32(const float * scores, int num_tokens, int n_experts, int k, int * indices, float * weights)
Batched top-K selection for multiple tokens.
topk_f32
Forward
void topk_f32(const float * scores, int n, int k, int * indices, float * values)
Find top-K indices and values from a score vector.
topology_discover
Forward
int topology_discover(SystemTopology * topo)
topology_discover_affinity
Forward
int topology_discover_affinity(AffinityInfo * aff)
topology_discover_cache
Forward
int topology_discover_cache(CacheTopology * cache)
topology_discover_cpu
Forward
int topology_discover_cpu(CPUInfo * cpu)
topology_discover_memory
Forward
int topology_discover_memory(MemoryInfo * mem)
topology_discover_network
Forward
int topology_discover_network(NetworkTopology * net)
topology_discover_numa
Forward
int topology_discover_numa(NUMATopology * numa)
topology_discover_pcie
Forward
int topology_discover_pcie(PCIeTopology * pcie)
topology_estimate_channels_from_bandwidth
Forward
int topology_estimate_channels_from_bandwidth(float measured_bw_gbs, int memory_speed_mhz, const char * memory_type)
topology_estimate_memory_bandwidth
Forward
float topology_estimate_memory_bandwidth(const MemoryInfo * mem)
topology_estimate_network_training_time
Forward
float topology_estimate_network_training_time(const NetworkTopology * net, uint64_t model_size_mb)
topology_generate_recommendations
Forward
int topology_generate_recommendations(const SystemTopology * topo, RecommendationList * recs)
topology_measure_memory_bandwidth
Forward
float topology_measure_memory_bandwidth(void)
topology_measure_memory_bandwidth_ex
Forward
float topology_measure_memory_bandwidth_ex(int * numa_node_out, int * num_threads_out)
topology_print_affinity
Forward
void topology_print_affinity(const AffinityInfo * aff)
topology_print_cache
Forward
void topology_print_cache(const CacheTopology * cache, int logical_cores)
topology_print_cpu
Forward
void topology_print_cpu(const CPUInfo * cpu)
topology_print_distributed_potential
Forward
void topology_print_distributed_potential(const SystemTopology * topo)
topology_print_memory
Forward
void topology_print_memory(const MemoryInfo * mem)
topology_print_network
Forward
void topology_print_network(const NetworkTopology * net)
topology_print_numa
Forward
void topology_print_numa(const NUMATopology * numa, int sockets)
topology_print_pcie
Forward
void topology_print_pcie(const PCIeTopology * pcie)
topology_print_recommendations
Forward
void topology_print_recommendations(const RecommendationList * recs)
topology_print_summary
Forward
void topology_print_summary(const SystemTopology * topo)
trim_string
Forward
void trim_string(char * str)
unpack_q4_k_scales
Forward
void unpack_q4_k_scales(const uint8_t * scales, uint8_t * sc, uint8_t * m)
Unpack Q4_K sub-block scales and mins.
unpack_q5_k_scales
Forward
void unpack_q5_k_scales(const uint8_t * scales, uint8_t * sc, uint8_t * m)
Unpack Q5_K sub-block scales and mins.
utf8_char_len
Forward
int utf8_char_len(unsigned char c)
utf8_len
Forward
int utf8_len(unsigned char c)
v6_prefill
Forward
void v6_prefill(const float * embed_weight, const int32_t * tokens, int num_tokens, float * logits)
vec_dot_q5_0_q8_0
Forward
void vec_dot_q5_0_q8_0(int n, float * s, const void * vx, const void * vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.
vec_dot_q5_0_q8_0_ref
Forward
void vec_dot_q5_0_q8_0_ref(int n, float * s, const void * vx, const void * vy)
Quantized dot product: Q5_0 weights x Q8_0 input (scalar reference)
vec_dot_q6_k_q8_k
Forward
void vec_dot_q6_k_q8_k(int n, float * s, const void * vx, const void * vy)
Q6_K x Q8_K dot product (single row)
vec_dot_q8_0_q8_0
Forward
void vec_dot_q8_0_q8_0(int n, float * s, const void * vx, const void * vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.
vec_dot_q8_0_q8_0_ref
Forward
void vec_dot_q8_0_q8_0_ref(int n, float * s, const void * vx, const void * vy)
Quantized dot product: Q8_0 weights x Q8_0 input (scalar reference)
vec_scale_parallel
Forward
void vec_scale_parallel(float * y, float scale, int n, int ith, int nth)
vec_zero_parallel
Forward
void vec_zero_parallel(float * y, int n, int ith, int nth)
weighted_sum_f32
Forward
void weighted_sum_f32(float * y, const float ** vectors, const float * weights, int k, int n)
Weighted sum of k vectors: y = sum_i(weights[i] * vectors[i])
worker_main
Forward
void * worker_main(void * arg)
zero_gradients_bf16
Forward
void zero_gradients_bf16(uint16_t * grad, size_t numel)
Zero out gradient buffer (bf16)
zero_gradients_f32
Forward
void zero_gradients_f32(float * grad, size_t numel)
Zero out gradient buffer (fp32)
Usage Example
Forward + Backward Pass
// Forward pass
rmsnorm_forward(input, gamma, norm_out, rstd_cache, tokens, d_model, d_model, eps);
attention_forward_causal_head_major_gqa(q, k, v, scores, attn_out, heads, kv_heads, tokens, head_dim, head_dim, ctx_len);
swiglu_forward(mlp_in, mlp_out, tokens, hidden_dim);

// Backward pass (reverse order)
swiglu_backward(mlp_in, d_mlp_out, d_mlp_in, tokens, hidden_dim);
attention_backward_causal_head_major_gqa(d_attn_out, q, k, v, scores, d_q, d_k, d_v, d_scores, heads, kv_heads, tokens, head_dim, head_dim, ctx_len);
rmsnorm_backward(d_norm_out, input, gamma, rstd_cache, d_input, d_gamma, tokens, d_model, d_model);