API Reference

Complete API documentation for C-Kernel-Engine kernels. All functions are exported from libckernel_engine.so and can be called from C or via Python ctypes.

Auto-generated from source
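This documentation is extracted from the C header files using Doxygen. Functions marked Forward compute activations, and functions marked Backward compute gradients. Utility helpers, such as scratch-size queries and implementation-name getters, are also listed as Forward.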

Quick Reference

Include Header

#include "ckernel_engine.h"

Link Library

-lckernel_engine

Python ctypes

lib = ctypes.CDLL("libckernel_engine.so")

Memory Layouts

All kernels use consistent memory layouts optimized for cache efficiency:

Buffer Layout Description
input/output [B, T, D] Batch × Tokens × Embedding dimension
Q [H, T, d_k] num_heads × Tokens × head_dim (head-major)
K, V [H_kv, T, d_k] num_kv_heads × Tokens × head_dim (for GQA)
scores [H, T, T] num_heads × query_tokens × key_tokens
weights [out, in] Row-major weight matrices

Kernel Functions

GEMM (Matrix Multiplication) (92)
ck_gemm_add_bias Forward
void ck_gemm_add_bias(float * C, const float * bias, int M, int N)

ck_gemm_nt_head_major_q5_0 Forward
void ck_gemm_nt_head_major_q5_0(const float * attn_out, const void * wo, const float * bias, float * output, int tokens, int embed_dim, int num_heads, int head_dim)

Output projection from head-major attention (Q5_0 weights, auto-dispatch)

ck_gemm_nt_head_major_q8_0 Forward
void ck_gemm_nt_head_major_q8_0(const float * attn_out, const void * wo, const float * bias, float * output, int tokens, int embed_dim, int num_heads, int head_dim)

Output projection from head-major attention (Q8_0 weights)

ck_gemm_nt_quant Forward
void ck_gemm_nt_quant(const float * A, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dtype)

ck_test_gemm_q4_k Forward
void ck_test_gemm_q4_k(const void * weight_q4k, const float * input_f32, float * output, int rows, int cols, int n_tokens)

Q4_K GEMM - batched matrix multiply with quantized weights.

ck_test_gemm_q5_0 Forward
void ck_test_gemm_q5_0(const void * weight_q5_0, const float * input_f32, float * output, int rows, int cols, int n_tokens)

Test Q5_0 x Q8_0 GEMM (batch matrix multiply)

ck_test_gemm_q5_1 Forward
void ck_test_gemm_q5_1(const void * weight_q5_1, const float * input_f32, float * output, int rows, int cols, int n_tokens)

Test Q5_1 x Q8_0 GEMM (batch matrix multiply)

ck_test_gemm_q5_k Forward
void ck_test_gemm_q5_k(const void * weight_q5_k, const float * input_f32, float * output, int rows, int cols, int n_tokens)

Test Q5_K x Q8_K GEMM (batch matrix multiply)

ck_test_gemm_q6_k Forward
void ck_test_gemm_q6_k(const void * weight_q6k, const float * input_f32, float * output, int rows, int cols, int n_tokens)

Test Q6_K x Q8_K GEMM (batch matrix multiply)

ck_test_gemm_q8_0 Forward
void ck_test_gemm_q8_0(const void * weight_q8_0, const float * input_f32, float * output, int rows, int cols, int n_tokens)

Test Q8_0 x Q8_0 GEMM (batch matrix multiply)

ckernel_sgemm_native Forward
void ckernel_sgemm_native(int M, int N, int K, const float * A, int lda, const float * B, int ldb, const float * bias, float * C, int ldc)

Native GEMM backend that directly reuses the C-Transformer GEMM kernel.

compute_gemm_params Forward
void compute_gemm_params(const CPUInfo * cpu, GEMMParams * params)

fused_rmsnorm_gemm_2d_tiled Forward
void fused_rmsnorm_gemm_2d_tiled(const float * x, const float * gamma, const float * W, float * output, int seq_len, int hidden, int out_dim, float eps, float * x_norm_scratch)

Fused RMSNorm + single GEMM with 2D tiling (weight reuse)

gemm_avx512_parallel Forward
void gemm_avx512_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_batch_int8_impl_name Forward
const char * gemm_batch_int8_impl_name(void)

Get the best implementation name for logging/debugging.

gemm_bf16_fp32out Forward
void gemm_bf16_fp32out(const uint16_t * A, const uint16_t * B, const float * bias, float * C, int M, int N, int K)

gemm_bias_gelu_fused Forward
void gemm_bias_gelu_fused(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_bias_relu_fused Forward
void gemm_bias_relu_fused(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_bias_silu_fused Forward
void gemm_bias_silu_fused(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_blocked_serial Forward
void gemm_blocked_serial(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_blocked_serial_bf16 Forward
void gemm_blocked_serial_bf16(const uint16_t * A, const uint16_t * B, const uint16_t * bias, uint16_t * C, int M, int N, int K)

gemm_f16 Forward
void gemm_f16(float * Y, const uint16_t * W, const float * X, int M, int N, int K)

Auto-dispatch GEMM based on available SIMD.

gemm_f16_backward Backward
void gemm_f16_backward(float * dX, const uint16_t * W, const float * dY, int M, int N, int K)

Batched backward pass.

gemm_f16_ref Forward
void gemm_f16_ref(float * Y, const uint16_t * W, const float * X, int M, int N, int K)

Matrix-matrix multiply with FP16 weights (scalar reference)

gemm_fine_grained_parallel Forward
void gemm_fine_grained_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_get_backend
const char * gemm_get_backend(void)

gemm_init_threads Forward
void gemm_init_threads(void)

gemm_microkernel Forward
void gemm_microkernel(const float * A, const float * B, float * C, int M, int N, int K, int B_transposed)

gemm_microkernel_blocked Forward
void gemm_microkernel_blocked(const float * A, const float * B, float * C, int M, int N, int K)

gemm_microkernel_blocked_bt Forward
void gemm_microkernel_blocked_bt(const float * A, const float * B, float * C, int M, int N, int K)

gemm_microkernel_edge Forward
void gemm_microkernel_edge(int m, int n, int K, const float * A, int lda, const float * B, int ldb, float * C, int ldc, int first_k)

gemm_microkernel_packed Forward
void gemm_microkernel_packed(const float * A, const float * B, float * C, int M, int N, int K)

gemm_microkernel_sequential Forward
void gemm_microkernel_sequential(const float * A, const float * B, float * C, int M, int N, int K)

gemm_naive_parallel Forward
void gemm_naive_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_naive_serial_double Forward
void gemm_naive_serial_double(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_nn_avx512 Forward
void gemm_nn_avx512(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_nn_bf16 Forward
void gemm_nn_bf16(const uint16_t * A, const uint16_t * B, const uint16_t * bias, uint16_t * C, int M, int N, int K)

gemm_nn_blocked Forward
void gemm_nn_blocked(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_nn_parallel Forward
void gemm_nn_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_nn_serial_double Forward
void gemm_nn_serial_double(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_nt Forward
void gemm_nt(const float * input, const float * weight, float * output, int rows, int cols, int common)

gemm_nt_matvec_parallel Forward
void gemm_nt_matvec_parallel(const float * A, const float * B, const float * bias, float * C, int N, int K)

gemm_nt_q4_0 Forward
void gemm_nt_q4_0(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.

gemm_nt_q4_1 Forward
void gemm_nt_q4_1(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

GEMM with transposed Q4_1 weights: C = A @ B^T.

gemm_nt_q4_k Forward
void gemm_nt_q4_k(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q4_k_q8_k Forward
void gemm_nt_q4_k_q8_k(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q5_0 Forward
void gemm_nt_q5_0(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q5_0_q8_0 Forward
void gemm_nt_q5_0_q8_0(const void * A_q8, const void * B_q5, const float * bias, float * C, int M, int N, int K)

Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.

gemm_nt_q5_0_q8_0_ref Forward
void gemm_nt_q5_0_q8_0_ref(const void * A, const void * B, float * C, int M, int N, int K)

Reference implementation of gemm_nt_q5_0_q8_0: C = A @ B^T with Q8_0 activations A and Q5_0 weights B (no bias).

gemm_nt_q5_0_q8_0_unroll_avx Forward
void gemm_nt_q5_0_q8_0_unroll_avx(const void * A_q8, const void * B_q5, const float * bias, float * C, int M, int N, int K)

gemm_nt_q5_0_ref Forward
void gemm_nt_q5_0_ref(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

GEMM with transposed Q5_0 weights: C = A @ B^T.

gemm_nt_q5_0_sse Forward
void gemm_nt_q5_0_sse(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q5_0_sse_v2 Forward
void gemm_nt_q5_0_sse_v2(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q5_1 Forward
void gemm_nt_q5_1(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

GEMM with transposed Q5_1 weights: C = A @ B^T.

gemm_nt_q5_k Forward
void gemm_nt_q5_k(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q5_k_ref Forward
void gemm_nt_q5_k_ref(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q6_k Forward
void gemm_nt_q6_k(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q6_k_q8_k Forward
void gemm_nt_q6_k_q8_k(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K)

NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.

gemm_nt_q6_k_ref Forward
void gemm_nt_q6_k_ref(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q6_k_sse Forward
void gemm_nt_q6_k_sse(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q8_0 Forward
void gemm_nt_q8_0(const float * A, const void * B, const float * bias, float * C, int M, int N, int K)

Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.

gemm_nt_q8_0_dispatch Forward
void gemm_nt_q8_0_dispatch(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dt)

gemm_nt_q8_0_mlp_dispatch Forward
void gemm_nt_q8_0_mlp_dispatch(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dt)

gemm_nt_q8_0_q8_0 Forward
void gemm_nt_q8_0_q8_0(const void * A, const void * B, const float * bias, float * C, int M, int N, int K)

gemm_nt_q8_0_q8_0 with optional bias (matches header signature)

gemm_nt_q8_0_q8_0_ref Forward
void gemm_nt_q8_0_q8_0_ref(const void * A, const void * B, float * C, int M, int N, int K)

Scalar reference: gemm_nt_q8_0_q8_0.

gemm_nt_q8_k_mlp_dispatch Forward
void gemm_nt_q8_k_mlp_dispatch(const void * A_q8, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dt)

gemm_nt_q8_k_qkv_dispatch Forward
void gemm_nt_q8_k_qkv_dispatch(const void * A_q8k, const void * B, const float * bias, float * C, int M, int N, int K, CKDataType dt)

gemm_q4_0 Forward
void gemm_q4_0(float * Y, const void * W, const float * X, int M, int N, int K)

Matrix-matrix multiply with Q4_0 weights.

gemm_q4_0_backward Backward
void gemm_q4_0_backward(float * dX, const void * W, const float * dY, int M, int N, int K)

Batched backward pass.

gemm_q4_1 Forward
void gemm_q4_1(float * Y, const void * W, const float * X, int M, int N, int K)

Matrix-matrix multiply with Q4_1 weights.

gemm_q4_1_backward Backward
void gemm_q4_1_backward(float * dX, const void * W, const float * dY, int M, int N, int K)

Batched backward pass.

gemm_q4_k Forward
void gemm_q4_k(float * Y, const void * W, const float * X, int M, int N, int K)

Auto-dispatch GEMM based on available SIMD.

gemm_q4_k_backward Backward
void gemm_q4_k_backward(float * dX, const void * W, const float * dY, int M, int N, int K)

Batched backward pass.

gemm_q4_k_q8_k Forward
void gemm_q4_k_q8_k(float * Y, const void * W, const void * X_q8, int M, int N, int K)

gemm_q4_k_q8_k_ref Forward
void gemm_q4_k_q8_k_ref(float * Y, const void * W, const void * X_q8, int M, int N, int K)

gemm_q4_k_ref Forward
void gemm_q4_k_ref(float * Y, const void * W, const float * X, int M, int N, int K)

Matrix-matrix multiply with Q4_K weights (scalar reference)

gemm_q5_0 Forward
void gemm_q5_0(float * Y, const void * W, const float * X, int M, int N, int K)

Matrix-matrix multiply with Q5_0 weights.

gemm_q5_0_backward Backward
void gemm_q5_0_backward(float * dX, const void * W, const float * dY, int M, int N, int K)

Batched backward pass.

gemm_q5_1 Forward
void gemm_q5_1(float * Y, const void * W, const float * X, int M, int N, int K)

Matrix-matrix multiply with Q5_1 weights.

gemm_q5_1_backward Backward
void gemm_q5_1_backward(float * dX, const void * W, const float * dY, int M, int N, int K)

Batched backward pass.

gemm_q6_k Forward
void gemm_q6_k(float * Y, const void * W, const float * X, int M, int N, int K)

gemm_q6_k_q8_k Forward
void gemm_q6_k_q8_k(float * Y, const void * W, const void * X_q8, int M, int N, int K)

GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.

gemm_q8_0 Forward
void gemm_q8_0(float * Y, const void * W, const float * X, int M, int N, int K)

Matrix-matrix multiply with Q8_0 weights.

gemm_q8_0_backward Backward
void gemm_q8_0_backward(float * dX, const void * W, const float * dY, int M, int N, int K)

Batched backward pass.

gemm_swiglu_fused Forward
void gemm_swiglu_fused(const float * x, const float * W_gate, const float * W_up, const float * b_gate, const float * b_up, float * output, int M, int N, int K)

gemm_tile_nt_strided Forward
void gemm_tile_nt_strided(const float * A, const float * B_tile, float * C, int tile_m, int tile_n, int K, int C_stride)

GEMM tile with N-dimension tiling (weight reuse)

gemm_tn_avx512 Forward
void gemm_tn_avx512(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_tn_bf16 Forward
void gemm_tn_bf16(const uint16_t * A, const uint16_t * B, const uint16_t * bias, uint16_t * C, int M, int N, int K)

gemm_tn_blocked Forward
void gemm_tn_blocked(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_tn_parallel Forward
void gemm_tn_parallel(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

gemm_tn_serial_double Forward
void gemm_tn_serial_double(const float * A, const float * B, const float * bias, float * C, int M, int N, int K)

get_gemm_params Forward
const GEMMParams * get_gemm_params(void)

Layer Normalization (10)
layernorm_backward_kernel Backward
void layernorm_backward_kernel(const float * d_output, const float * input, const float * gamma, const float * mean, const float * rstd, float * d_input, float * d_gamma, float * d_beta, int tokens, int d_model, int aligned_embed_dim)

Backward pass / gradient computation

layernorm_backward_kernel_bf16 Backward
void layernorm_backward_kernel_bf16(const uint16_t * d_output, const uint16_t * input, const float * gamma, const float * mean, const float * rstd, uint16_t * d_input, float * d_gamma, float * d_beta, int tokens, int d_model, int aligned_embed_dim, float * scratch_d_output, float * scratch_input, float * scratch_d_input)

Backward pass / gradient computation

layernorm_forward_rolled_slice Forward
void layernorm_forward_rolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps)

Forward pass computation

layernorm_forward_rolled_slice_bf16 Forward
void layernorm_forward_rolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps, float * scratch_input, float * scratch_output)

Forward pass computation

layernorm_forward_unrolled_slice Forward
void layernorm_forward_unrolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)

Forward pass computation

layernorm_forward_unrolled_slice_bf16 Forward
void layernorm_forward_unrolled_slice_bf16(const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps, float * scratch_input, float * scratch_output)

Forward pass computation

layernorm_forward_unrolled_slice_scalar Forward
void layernorm_forward_unrolled_slice_scalar(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)

Forward pass computation

layernorm_naive_serial Forward
void layernorm_naive_serial(const float * input, const float * gamma, const float * beta, float * output, float * mean_cache, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)

layernorm_naive_serial_matched_precision Forward
void layernorm_naive_serial_matched_precision(const float * input, const float * gamma, const float * beta, float * output, float * mean_cache, float * rstd_cache, int tokens, int d_model, float eps)

zero_layernorm_padding Forward
void zero_layernorm_padding(float * out_ptr, int d_model, int aligned_embed_dim)

RMS Normalization (42)
ck_layer_backward_rmsnorm_swiglu Backward
void ck_layer_backward_rmsnorm_swiglu(const CKLayerBackwardParams * p)

Backward pass / gradient computation

ck_layer_forward_rmsnorm_swiglu Forward
void ck_layer_forward_rmsnorm_swiglu(const CKLayerForwardParams * p)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_decode Forward
void ck_layer_forward_rmsnorm_swiglu_decode(const CKLayerForwardParams * p, int token_index, int cache_capacity)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_decode_fused Forward
void ck_layer_forward_rmsnorm_swiglu_decode_fused(const CKLayerForwardParams * p, int token_index, int cache_capacity)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_decode_fused_attn Forward
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn(const CKLayerForwardParams * p, int token_index, int cache_capacity)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl Forward
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(const CKLayerForwardParams * p, int token_index, int cache_capacity, int fuse_mlp)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp Forward
void ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_mlp(const CKLayerForwardParams * p, int token_index, int cache_capacity)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_decode_q4_k Forward
void ck_layer_forward_rmsnorm_swiglu_decode_q4_k(const CKLayerForwardParamsQ4K * p, int token_index, int cache_capacity)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_decode_quant Forward
void ck_layer_forward_rmsnorm_swiglu_decode_quant(const CKLayerForwardParamsQ4K * p, int token_index, int cache_capacity)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_q4_k Forward
void ck_layer_forward_rmsnorm_swiglu_q4_k(const CKLayerForwardParamsQ4K * p)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_quant Forward
void ck_layer_forward_rmsnorm_swiglu_quant(const CKLayerForwardParamsQ4K * p)

Forward pass computation

ck_layer_forward_rmsnorm_swiglu_ref Forward
void ck_layer_forward_rmsnorm_swiglu_ref(const CKLayerForwardParams * p)

Forward pass computation

ck_test_rmsnorm Forward
void ck_test_rmsnorm(const float * input, const float * weight, float * output, int n_tokens, int dim, float eps)

RMSNorm.

fused_rmsnorm Forward
void fused_rmsnorm(const float * input, const float * gamma, const float * beta, float * output, int hidden, float eps)

Fused RMSNorm - writes to pre-allocated buffer.

fused_rmsnorm_linear_q4k Forward
void fused_rmsnorm_linear_q4k(float * y, const float * x, const float * gamma, const void * W_q4k, int M, int K, float eps)

Fused RMSNorm + Q4_K Linear projection.

fused_rmsnorm_qkv Forward
void fused_rmsnorm_qkv(const float * input, const float * gamma, const float * W_qkv, const float * b_qkv, float * q_out, float * k_out, float * v_out, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps)

Fused RMSNorm with fused QKV projection.

fused_rmsnorm_qkv_prefill Forward
void fused_rmsnorm_qkv_prefill(const float * x, const float * gamma, const float * Wq, const float * Wk, const float * Wv, float * Q, float * K, float * V, int seq_len, int hidden, int q_dim, int kv_dim, float eps, float * scratch)

Fused RMSNorm + QKV projection for prefill (v3 optimized)

fused_rmsnorm_qkv_prefill_head_major Forward
void fused_rmsnorm_qkv_prefill_head_major(const float * x, const float * gamma, const float * Wq, const float * Bq, const float * Wk, const float * Bk, const float * Wv, const float * Bv, float * Q, float * K, float * V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float * scratch)

Fused RMSNorm + QKV projection for prefill (head-major outputs)

fused_rmsnorm_qkv_prefill_head_major_quant Forward
void fused_rmsnorm_qkv_prefill_head_major_quant(const float * x, const float * gamma, const void * Wq, const float * Bq, CKDataType wq_dt, const void * Wk, const float * Bk, CKDataType wk_dt, const void * Wv, const float * Bv, CKDataType wv_dt, float * Q, float * K, float * V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void * scratch)

Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)

fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size Forward
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim)

Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.

fused_rmsnorm_qkv_scratch_size Forward
size_t fused_rmsnorm_qkv_scratch_size(int hidden)

Get scratch size for fused prefill.

mega_fuse_rmsnorm_qkv Forward
void mega_fuse_rmsnorm_qkv(float * q_out, float * k_out, float * v_out, const float * input, const float * gamma, const float * W_qkv, const float * b_qkv, int hidden, int num_heads, int num_kv_heads, int head_dim, float eps)

Phase 1: Fused RMSNorm + QKV (intermediates in registers)

mega_fuse_rmsnorm_qkv_avx Forward
void mega_fuse_rmsnorm_qkv_avx(float * q_out, float * k_out, float * v_out, const float * input, const float * gamma, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps)

Fused RMSNorm + QKV for decode (single token)

mega_fuse_rmsnorm_qkv_rope Forward
void mega_fuse_rmsnorm_qkv_rope(float * q_out, float * k_out, float * v_out, const float * input, const float * gamma, const float * W_qkv, const float * b_qkv, const float * rope_cos, const float * rope_sin, int pos, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps)

Phase 2: Fused RMSNorm + QKV + RoPE.

rmsnorm_backward Backward
void rmsnorm_backward(const float * d_output, const float * input, const float * gamma, const float * rstd_cache, float * d_input, float * d_gamma, int tokens, int d_model, int aligned_embed_dim)

RMSNorm backward pass. Tests: test_rmsnorm.py::TestRMSNormBackward::test_backward_tokens, test_rmsnorm.py::TestRMSNormBackward::test_backward_single, test_parity.py::test_rmsnorm_backward_parity

rmsnorm_backward_bf16 Backward
void rmsnorm_backward_bf16(const uint16_t * d_output, const uint16_t * input, const float * gamma, const float * rstd_cache, uint16_t * d_input, float * d_gamma, int tokens, int d_model, int aligned_embed_dim)

Backward pass / gradient computation

rmsnorm_backward_int4 Backward
void rmsnorm_backward_int4(const uint8_t * d_output, const uint8_t * input, const float * gamma, const float * rstd_cache, uint8_t * d_input, float * d_gamma, int tokens, int d_model, int aligned_embed_dim, float * scratch_d_output, float * scratch_input, float * scratch_d_input)

Backward pass / gradient computation

rmsnorm_backward_int8 Backward
void rmsnorm_backward_int8(const int8_t * d_output, const int8_t * input, const float * gamma, const float * rstd_cache, int8_t * d_input, float * d_gamma, int tokens, int d_model, int aligned_embed_dim, float * scratch_d_output, float * scratch_input, float * scratch_d_input)

Backward pass / gradient computation

rmsnorm_forward Forward
void rmsnorm_forward(const float * input, const float * gamma, float * output, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)

RMSNorm forward pass. Tests: test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens, test_rmsnorm.py::TestRMSNormForward::test_fp32_single, test_rmsnorm.py::TestRMSNormForward::test_perf_rolled, test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat, test_parity.py::test_rmsnorm_parity

rmsnorm_forward_bf16 Forward
void rmsnorm_forward_bf16(const uint16_t * input, const float * gamma, uint16_t * output, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)

Forward pass computation

rmsnorm_forward_int4 Forward
void rmsnorm_forward_int4(const uint8_t * input, const float * gamma, uint8_t * output, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float * scratch_input, float * scratch_output)

Forward pass computation

rmsnorm_forward_int8 Forward
void rmsnorm_forward_int8(const int8_t * input, const float * gamma, int8_t * output, float * rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float * scratch_input, float * scratch_output)

Forward pass computation

rmsnorm_q8_k_fused Forward
void rmsnorm_q8_k_fused(const float * input, const float * gamma, void * vy, int tokens, int d_model, int aligned_embed_dim, float eps)

Fused RMSNorm + Q8_K Quantization

rmsnorm_qkv_fp32_fused Forward
void rmsnorm_qkv_fp32_fused(const float * x, const float * rms_weight, const float * wq, const float * wk, const float * wv, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)

rmsnorm_qkv_fp32_fused_v2 Forward
void rmsnorm_qkv_fp32_fused_v2(const float * x, const float * rms_weight, const float * wq, const float * wk, const float * wv, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)

rmsnorm_qkv_fp32_fused_v3 Forward
void rmsnorm_qkv_fp32_fused_v3(const float * x, const float * rms_weight, const float * wq, const float * wk, const float * wv, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)

rmsnorm_qkv_q4k_fused Forward
void rmsnorm_qkv_q4k_fused(const float * x, const float * rms_weight, const void * wq, const void * wk, const void * wv, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)

rmsnorm_qkv_separate_fp32 Forward
void rmsnorm_qkv_separate_fp32(const float * x, const float * rms_weight, const float * wq, const float * wk, const float * wv, float * normed, float * q_out, float * k_out, float * v_out, int embed_dim, int q_dim, int kv_dim, float eps)

rmsnorm_tile Forward
void rmsnorm_tile(const float * input, const float * gamma, float * output, int tile_m, int embed_dim, int aligned_embed_dim, float eps)

Compute RMSNorm for a tile of tokens.

simple_rmsnorm Forward
void simple_rmsnorm(const float * input, const float * gamma, float * output, int tokens, int d_model, float eps)

unfused_rmsnorm_linear_q4k_ref Forward
void unfused_rmsnorm_linear_q4k_ref(float * y, const float * x, const float * gamma, const void * W_q4k, int M, int K, float eps)

Reference (unfused) implementation for correctness testing.

unfused_rmsnorm_qkv_prefill Forward
void unfused_rmsnorm_qkv_prefill(const float * x, const float * gamma, const float * Wq, const float * Wk, const float * Wv, float * x_norm, float * Q, float * K, float * V, int seq_len, int hidden, int q_dim, int kv_dim, float eps)

Unfused version for comparison.

GELU Activation (10)
fast_gelu_scalar Forward
float fast_gelu_scalar(float x)

gelu_backward_exact Backward
void gelu_backward_exact(const float * input, const float * d_output, float * d_input, size_t n)

Backward pass / gradient computation

gelu_backward_exact_bf16 Backward
void gelu_backward_exact_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, size_t n, float * scratch_input, float * scratch_d_output, float * scratch_d_input)

Backward pass / gradient computation

gelu_backward_fast Backward
void gelu_backward_fast(const float * input, const float * d_output, float * d_input, size_t n)

Backward pass / gradient computation

gelu_backward_fast_bf16 Backward
void gelu_backward_fast_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, size_t n, float * scratch_input, float * scratch_d_output, float * scratch_d_input)

Backward pass / gradient computation

gelu_backward_scalar Backward
void gelu_backward_scalar(const float * input, const float * d_output, float * d_input, size_t n)

Backward pass / gradient computation

gelu_exact_inplace Forward
void gelu_exact_inplace(float * data, size_t n)

gelu_fast_inplace Forward
void gelu_fast_inplace(float * data, size_t n)

GELU activation forward (fast approximation, in-place). Tests: test_gelu.py::TestGELUForward::test_gelu_fast_inplace, test_gelu.py::TestGELUForward::test_gelu_vs_exact, test_parity.py::test_gelu_parity

gelu_fast_inplace_bf16 Forward
void gelu_fast_inplace_bf16(uint16_t * data, size_t n, float * scratch)

gelu_scalar Forward
float gelu_scalar(float x)

Softmax (11)
backward_causal_softmax_head_major Backward
void backward_causal_softmax_head_major(float * d_scores, const float * weights, int num_heads, int num_tokens, int aligned_context_window)

Backward pass / gradient computation

backward_causal_softmax_head_major_bf16 Backward
void backward_causal_softmax_head_major_bf16(uint16_t * d_scores, const uint16_t * weights, int num_heads, int num_tokens, int aligned_context_window, float * scratch_d_scores, float * scratch_weights)

Backward pass / gradient computation

causal_softmax_head_major Forward
void causal_softmax_head_major(float * scores, int num_heads, int num_tokens, int aligned_context_window)

Causal softmax (in-place, row-wise). Tests: test_softmax.py::TestSoftmaxForward::test_causal_softmax, test_softmax.py::TestSoftmaxForward::test_causal_vs_softmax, test_attention.py::TestAttentionForward::test_softmax_correctness

causal_softmax_head_major_bf16 Forward
void causal_softmax_head_major_bf16(uint16_t * scores, int num_heads, int num_tokens, int aligned_context_window, float * scratch)

causal_softmax_head_major_exact Forward
void causal_softmax_head_major_exact(float * scores, int num_heads, int num_tokens, int aligned_context_window)

Causal softmax (exact version using stdlib expf). Tests: test_softmax.py::TestSoftmaxForward::test_causal_softmax_exact, test_softmax.py::TestSoftmaxForward::test_exact_vs_fast

ck_test_softmax Forward
void ck_test_softmax(const float * input, float * output, int n)

Softmax (simple, non-causal)

softmax Forward
void softmax(float * x, int n)

softmax_cross_entropy_loss Forward
void softmax_cross_entropy_loss(const float * logits, const int32_t * targets, int tokens, int vocab_size, float * d_logits, float * loss_out)

softmax_cross_entropy_loss_bf16 Forward
void softmax_cross_entropy_loss_bf16(const uint16_t * logits, const int32_t * targets, int tokens, int vocab_size, uint16_t * d_logits, float * loss_out, float * scratch_logits, float * scratch_d_logits)

softmax_inplace Forward
void softmax_inplace(float * x, int n)

topk_softmax_f32 Forward
void topk_softmax_f32(const float * scores, int n, int k, int * indices, float * weights)

Find top-K indices with softmax-normalized weights.

Attention (48)
attention_backward_causal_head_major Backward
void attention_backward_causal_head_major(const float * d_output, const float * q, const float * k, const float * v, const float * attn_weights, float * d_q, float * d_k, float * d_v, float * d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

Causal attention backward (non-GQA version). Tests: test_attention_backward.py::TestAttentionBackward::test_backward, test_attention_backward.py::TestAttentionBackward::test_backward_vs_separate, test_parity.py::test_attention_backward_parity

attention_backward_causal_head_major_gqa Backward
void attention_backward_causal_head_major_gqa(const float * d_output, const float * q, const float * k, const float * v, const float * attn_weights, float * d_q, float * d_k, float * d_v, float * d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

GQA causal attention backward (score-matrix version). Tests: test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward, test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_vs_separate, test_parity.py::test_attention_backward_parity

attention_backward_causal_head_major_gqa_bf16 Backward
void attention_backward_causal_head_major_gqa_bf16(const uint16_t * d_output, float * d_x, const uint16_t * q, const uint16_t * k, const uint16_t * v, const float * attn_weights, float * d_q, float * d_k, float * d_v, float * d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float * scratch_d_output, float * scratch_q, float * scratch_k, float * scratch_v)

BF16 attention backward with caller-provided scratch buffers. Tests: bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_backward

attention_flash_cleanup Forward
void attention_flash_cleanup(void)

Clean up flash attention resources.

attention_flash_decode Forward
void attention_flash_decode(float * out, const float * q, const float * k, const float * v, int T_q, int T_k, int H, int D_h, float scale)

Main flash attention function with SIMD dispatch.

attention_flash_decode_scalar Forward
void attention_flash_decode_scalar(float * out, const float * q, const float * k, const float * v, int T_q, int T_k, int H, int D_h, float scale)

Scalar flash-style attention (online softmax)

attention_flash_init Forward
void attention_flash_init(int max_context, int max_heads, int max_head_dim)

Initialize flash attention buffers.

attention_flash_query_causal Forward
void attention_flash_query_causal(const float * q_vec, const float * k_head, const float * v_head, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float * out_vec)

attention_flash_query_sliding Forward
void attention_flash_query_sliding(const float * q_vec, const float * k_head, const float * v_head, int query_pos, int kv_tokens, int head_dim, int aligned_head_dim, float scale, float * out_vec, int sliding_window)

attention_forward_causal_head_major Forward
void attention_forward_causal_head_major(const float * q, const float * k, const float * v, float * scores, float * output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

Causal attention forward (score-matrix version). Tests: test_attention.py::TestAttentionForward::test_causal_forward, test_attention.py::TestAttentionForward::test_gqa_broadcast, test_attention.py::TestAttentionForward::test_exact_vs_fast, test_parity.py::test_attention_parity

attention_forward_causal_head_major_exact Forward
void attention_forward_causal_head_major_exact(const float * q, const float * k, const float * v, float * scores, float * output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

Causal attention forward (exact version using stdlib expf). Tests: test_attention.py::TestAttentionForward::test_exact_single, test_attention.py::TestAttentionForward::test_exact_vs_fast

attention_forward_causal_head_major_gqa Forward
void attention_forward_causal_head_major_gqa(const float * q, const float * k, const float * v, float * scores, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

GQA causal attention forward (score-matrix version). Tests: test_attention.py::TestAttentionForward::test_gqa_forward, test_attention.py::TestAttentionForward::test_gqa_broadcast, test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward, test_parity.py::test_attention_gqa_parity

attention_forward_causal_head_major_gqa_bf16 Forward
void attention_forward_causal_head_major_gqa_bf16(const uint16_t * q, const uint16_t * k, const uint16_t * v, float * scores, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float * scratch_q, float * scratch_k, float * scratch_v)

BF16 GQA causal attention forward. Tests: bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_forward, bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa, bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_flash

attention_forward_causal_head_major_gqa_exact Forward
void attention_forward_causal_head_major_gqa_exact(const float * q, const float * k, const float * v, float * scores, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

GQA causal attention forward (exact version using stdlib expf). Tests: test_attention.py::TestAttentionForward::test_gqa_exact, bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa

attention_forward_causal_head_major_gqa_flash Forward
void attention_forward_causal_head_major_gqa_flash(const float * q, const float * k, const float * v, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)

Flash attention forward for GQA (prefill, no score materialization). Tests: test_flash_attention.py::TestFlashAttention::test_flash_forward, test_flash_attention.py::TestFlashAttention::test_flash_vs_score_matrix, test_flash_attention.py::TestFlashAttention::test_flash_gqa, test_attention.py::TestAttentionForward::test_flash_forward

attention_forward_causal_head_major_gqa_flash_strided Forward
void attention_forward_causal_head_major_gqa_flash_strided(const float * q, const float * k, const float * v, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)

Flash attention forward with custom KV stride (for KV cache). Tests: test_flash_attention.py::TestFlashAttention::test_flash_strided, test_kv_cache_attention.py::TestKVCacheAttention::test_flash_attention

attention_forward_causal_head_major_gqa_flash_strided_sliding Forward
void attention_forward_causal_head_major_gqa_flash_strided_sliding(const float * q, const float * k, const float * v, float * output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)

Flash attention forward with sliding window (prefill). Tests: test_attention.py::TestAttentionForward::test_sliding_window_prefill

attention_forward_decode_head_major_gqa_flash Forward
void attention_forward_decode_head_major_gqa_flash(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)

Flash attention decode (single token attends to KV cache). Tests: test_flash_attention.py::TestFlashAttention::test_flash_decode, test_kv_cache_attention.py::TestKVCacheAttention::test_flash_decode, test_fused_attention_decode.py::TestFusedAttentionDecode::test_flash_decode, test_attention.py::TestAttentionForward::test_flash_decode

attention_forward_decode_head_major_gqa_flash_sliding Forward
void attention_forward_decode_head_major_gqa_flash_sliding(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)

Flash attention decode with sliding window. Tests: test_attention.py::TestAttentionForward::test_sliding_window_decode

attention_forward_decode_head_major_gqa_regular Forward
void attention_forward_decode_head_major_gqa_regular(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)

WARNING: This is NOT true flash attention!

attention_mlp_fused_fp32 Forward
void attention_mlp_fused_fp32(const float * q, const float * k_cache, const float * v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float * wo, const float * residual_1, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, int embed_dim, int intermediate_dim, float * hidden_out)

attention_mlp_fused_q4k Forward
void attention_mlp_fused_q4k(const float * q, const float * k_cache, const float * v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const void * wo, const float * residual_1, const float * rms_weight, float eps, const void * w_gate, const void * w_up, const void * w_down, int embed_dim, int intermediate_dim, float * hidden_out)

attention_mlp_separate_fp32 Forward
void attention_mlp_separate_fp32(const float * q, const float * k_cache, const float * v_cache, int seq_len, int num_heads, int num_kv_heads, int head_dim, float attn_scale, const float * wo, const float * residual_1, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, int embed_dim, int intermediate_dim, float * attn_out_buf, float * hidden_after_attn_buf, float * normed_buf, float * gate_buf, float * up_buf, float * mlp_out_buf, float * hidden_out)

ck_attention_flash_decode_wrapper Forward
void ck_attention_flash_decode_wrapper(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)

Wrapper to call TRUE flash attention from orchestration layer.

ck_attention_project_head_major Forward
void ck_attention_project_head_major(const float * attn_out, const float * wo, const float * bo, float * out, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

ck_attention_project_head_major_backward Backward
void ck_attention_project_head_major_backward(const float * d_out, const float * attn_out, const float * wo, float * d_attn_out, float * d_wo, float * d_bo, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

Backward pass / gradient computation

ck_attention_project_head_major_decode_token Forward
void ck_attention_project_head_major_decode_token(const float * attn_token, const float * wo, const float * bo, float * out_token, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)

ck_attention_project_head_major_decode_token_residual Forward
void ck_attention_project_head_major_decode_token_residual(const float * attn_token, const float * wo, const float * bo, const float * residual_in, float * proj_out, float * residual_out, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim)

ck_attention_project_head_major_q4_k Forward
void ck_attention_project_head_major_q4_k(const float * attn_out, const void * wo, const float * bo, float * out, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

ck_attention_project_head_major_q4_k_q8_k Forward
void ck_attention_project_head_major_q4_k_q8_k(const float * attn_out, const void * wo, const float * bo, float * out, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

ck_attention_project_head_major_quant Forward
void ck_attention_project_head_major_quant(const float * attn_out, const void * wo, const float * bo, float * out, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, CKDataType wo_dtype)

ck_attention_project_head_major_ref Forward
void ck_attention_project_head_major_ref(const float * attn_out, const float * wo, const float * bo, float * out, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

ck_test_attention_causal Forward
void ck_test_attention_causal(const float * q, const float * k, const float * v, float * out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim)

Multi-head causal attention for prefill (head-major layout)

ck_test_attention_decode_sliding Forward
void ck_test_attention_decode_sliding(const float * q_token, const float * k_cache, const float * v_cache, float * out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int sliding_window)

Test sliding-window attention (decode mode)

ck_test_attention_sliding_window Forward
void ck_test_attention_sliding_window(const float * q, const float * k, const float * v, float * out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim, int sliding_window)

Test sliding-window attention (prefill)

fused_flash_attention_all_heads Forward
void fused_flash_attention_all_heads(float * o_out, const float * q_all, const float * kv_cache_k, const float * kv_cache_v, int num_heads, int num_kv_heads, int head_dim, int seq_len, int kv_tile_size)

Fused Flash Attention for all heads (parallel dispatch)

fused_flash_attention_head Forward
void fused_flash_attention_head(float * o_out, const float * q, const float * kv_cache_k, const float * kv_cache_v, int kv_head_idx, int seq_len, int head_dim, int kv_tile_size)

Fused Flash Attention for single head.

mega_fuse_flash_attention_avx Forward
void mega_fuse_flash_attention_avx(float * o_out, const float * q, const float * kv_cache_k, const float * kv_cache_v, int num_heads, int num_kv_heads, int seq_len, int cache_capacity, int head_dim, int aligned_head_dim)

Flash attention with online softmax (AVX version)

mega_fused_attention Forward
void mega_fused_attention(float * output, const float * input, const float * residual, const float * W_qkv, const float * b_qkv, const float * W_o, const float * b_o, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int pos, int seq_len, int hidden, int num_heads, int num_kv_heads, int head_dim, int max_seq, float eps)

Complete mega-fused attention block.

mega_fused_attention_decode Forward
void mega_fused_attention_decode(float * output, const float * input, const float * residual, const float * ln1_gamma, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, const float * wo, const float * bo, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps)

Mega-fused attention for decode mode (single token)

mega_fused_attention_decode_q5_0 Forward
void mega_fused_attention_decode_q5_0(float * output, const float * input, const float * residual, const void * wq_q5_0, const void * wk_q5_0, const void * wv_q8_0, const void * wo_q5_0, const float * ln_gamma, const float * bq, const float * bk, const float * bv, const float * bo, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void * scratch)

Serial mega-fused attention decode kernel.

mega_fused_attention_decode_q5_0_parallel_simd Forward
void mega_fused_attention_decode_q5_0_parallel_simd(float * output, const float * input, const float * residual, const void * wq_q5_0, const void * wk_q5_0, const void * wv_q8_0, const void * wo_q5_0, const float * ln_gamma, const float * bq, const float * bk, const float * bv, const float * bo, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int pos, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int cache_capacity, float eps, void * scratch, int ith, int nth)

Parallel SIMD mega-fused attention decode kernel (threadpool-aware)

mega_fused_attention_decode_scratch_size Forward
int mega_fused_attention_decode_scratch_size(int AE, int H, int KV, int AD)

Calculate scratch buffer size needed for the kernel.

mega_fused_attention_prefill Forward
void mega_fused_attention_prefill(float * output, const float * input, const float * residual, const float * ln1_gamma, const void * wq, const float * bq, CKDataType wq_dt, const void * wk, const float * bk, CKDataType wk_dt, const void * wv, const float * bv, CKDataType wv_dt, const void * wo, const float * bo, CKDataType wo_dt, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void * scratch)

Mega-fused attention for prefill mode (multiple tokens)

mega_fused_attention_prefill_q8_0 Forward
void mega_fused_attention_prefill_q8_0(float * output, const float * input, const float * residual, const float * ln1_gamma, const void * wq, const float * bq, CKDataType wq_dt, const void * wk, const float * bk, CKDataType wk_dt, const void * wv, const float * bv, CKDataType wv_dt, const void * wo, const float * bo, CKDataType wo_dt, float * kv_cache_k, float * kv_cache_v, const float * rope_cos, const float * rope_sin, int start_pos, int tokens, int cache_capacity, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, float eps, void * scratch)

Mega-fused prefill attention kernel (Q8_0 out-proj)

mega_fused_attention_prefill_q8_0_scratch_size Forward
size_t mega_fused_attention_prefill_q8_0_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

Get scratch buffer size for mega_fused_attention_prefill_q8_0.

mega_fused_attention_prefill_scratch_size Forward
size_t mega_fused_attention_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

Get scratch buffer size for mega_fused_attention_prefill.

simple_attention Forward
void simple_attention(const float * q, const float * k, const float * v, float * output, int num_heads, int num_kv_heads, int seq_len, int head_dim)

MLP / Feed-Forward (30)
ck_mlp_swiglu_forward Forward
void ck_mlp_swiglu_forward(const float * input, const float * w1, const float * b1, const float * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)

Forward pass computation

ck_mlp_swiglu_forward_fully_fused_token Forward
void ck_mlp_swiglu_forward_fully_fused_token(const float * input_row, const float * w1, const float * b1, const float * w2, const float * b2, float * output_row, int aligned_embed_dim, int aligned_intermediate_dim)

Forward pass computation

ck_mlp_swiglu_forward_fused_token Forward
void ck_mlp_swiglu_forward_fused_token(const float * input_row, const float * w1, const float * b1, const float * w2, const float * b2, float * swiglu_row, float * output_row, int aligned_embed_dim, int aligned_intermediate_dim)

Forward pass computation

ck_mlp_swiglu_forward_q4_k Forward
void ck_mlp_swiglu_forward_q4_k(const float * input, const void * w1, const float * b1, const void * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)

Forward pass computation

ck_mlp_swiglu_forward_q4_k_q8_k Forward
void ck_mlp_swiglu_forward_q4_k_q8_k(const float * input, const void * w1, const float * b1, const void * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int aligned_embed_dim, int aligned_intermediate_dim)

Forward pass computation

ck_mlp_swiglu_forward_q4_k_q8_k_prefill Forward
void ck_mlp_swiglu_forward_q4_k_q8_k_prefill(const float * input, const void * w1, const float * b1, const void * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)

Forward pass computation

ck_mlp_swiglu_forward_quant Forward
void ck_mlp_swiglu_forward_quant(const float * input, const void * w1, const float * b1, CKDataType w1_dtype, const void * w2, const float * b2, CKDataType w2_dtype, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)

Forward pass computation

ck_mlp_swiglu_forward_ref Forward
void ck_mlp_swiglu_forward_ref(const float * input, const float * w1, const float * b1, const float * w2, const float * b2, float * fc1_out, float * swiglu_out, float * output, int tokens, int aligned_embed_dim, int aligned_intermediate_dim)

Forward pass computation

ck_test_outproj_mlp_fused_q5_0 Forward
void ck_test_outproj_mlp_fused_q5_0(const float * attn_out, const float * residual, const float * ln2_gamma, const void * wo, const void * w1, const void * w2, float * output, int tokens, int num_heads, int head_dim, int embed_dim, int intermediate, float eps, int w2_is_q6k)

Test mega-fused OutProj + MLP kernel (Q5_0 weights)

fused_mlp_swiglu_decode Forward
void fused_mlp_swiglu_decode(const float * x, const float * W_gate, const float * W_up, const float * W_down, const float * b_gate, const float * b_up, const float * b_down, float * output, int D, int Hff)

fused_mlp_swiglu_decode_tiled Forward
void fused_mlp_swiglu_decode_tiled(const float * x, const float * W_gate, const float * W_up, const float * W_down, const float * b_gate, const float * b_up, const float * b_down, float * output, int D, int Hff)

fused_mlp_swiglu_decode_v2 Forward
void fused_mlp_swiglu_decode_v2(const float * x, const float * W_gate, const float * W_up, const float * W_down, const float * b_gate, const float * b_up, const float * b_down, float * output, int D, int Hff)

fused_mlp_swiglu_prefill Forward
void fused_mlp_swiglu_prefill(const float * x, const float * W_gate, const float * W_up, const float * W_down, float * output, int seq_len, int hidden, int intermediate, float * scratch)

Fused MLP (Gate + Up + SwiGLU + Down) for prefill.

fused_mlp_swiglu_prefill_bias Forward
void fused_mlp_swiglu_prefill_bias(const float * x, const float * W_gate, const float * W_up, const float * W_down, const float * B_gate, const float * B_up, const float * B_down, float * output, int seq_len, int hidden, int intermediate, float * scratch)

Fused MLP for prefill with proper tiling.

fused_mlp_swiglu_prefill_w1w2_quant Forward
void fused_mlp_swiglu_prefill_w1w2_quant(const float * x, const void * W1, const float * B1, CKDataType w1_dt, const void * W2, const float * B2, CKDataType w2_dt, float * output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void * scratch)

Quantized fused MLP for prefill (W1=gate+up, W2=down)

fused_mlp_swiglu_prefill_w1w2_quant_scratch_size Forward
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)

Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.

fused_mlp_swiglu_scratch_size Forward
size_t fused_mlp_swiglu_scratch_size(int intermediate)

Get scratch size for fused MLP.

layer_fused_attn_mlp_qkv_q4k Forward
void layer_fused_attn_mlp_qkv_q4k(const float * q, const float * k_cache, const float * v_cache, int seq_len, float attn_scale, const void * wo, const float * rms_weight_mlp, const void * w_gate, const void * w_up, const void * w_down, const float * rms_weight_attn, const void * wq_next, const void * wk_next, const void * wv_next, const float * residual_in, int embed_dim, int intermediate_dim, int num_heads, int num_kv_heads, int head_dim, float eps, float * q_next, float * k_next, float * v_next, float * hidden_out)

mega_fused_outproj_mlp_prefill Forward
void mega_fused_outproj_mlp_prefill(float * output, const float * attn_out, const float * residual, const float * ln2_gamma, const void * wo, const float * bo, CKDataType wo_dt, const void * w1, const float * b1, CKDataType w1_dt, const void * w2, const float * b2, CKDataType w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void * scratch)

Mega-fused post-attention block (out-proj + RMSNorm2 + MLP) for prefill.

mega_fused_outproj_mlp_prefill_scratch_size Forward
size_t mega_fused_outproj_mlp_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)

Get scratch buffer size for mega_fused_outproj_mlp_prefill.

mlp_fused_fp32_v2 Forward
void mlp_fused_fp32_v2(const float * hidden_in, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, int embed_dim, int intermediate_dim, float * hidden_out)

mlp_fused_fp32_v3 Forward
void mlp_fused_fp32_v3(const float * hidden_in, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, int embed_dim, int intermediate_dim, float * hidden_out)

mlp_parallel Forward
void mlp_parallel(const void * ln2_q8, const void * W_gate, const void * W_up, const void * W_down, float * gate_buf, float * up_buf, float * swiglu_buf, void * down_q8, float * mlp_out, int intermediate, int embed_dim, int num_threads)

Parallel MLP (gate/up + SwiGLU + down projection).

mlp_q8_0_dtype_supported Forward
int mlp_q8_0_dtype_supported(CKDataType dt)

mlp_q8_k_dtype_supported Forward
int mlp_q8_k_dtype_supported(CKDataType dt)

mlp_separate_fp32 Forward
void mlp_separate_fp32(const float * hidden_in, const float * rms_weight, float eps, const float * w_gate, const float * w_up, const float * w_down, float * normed_buf, float * gate_buf, float * up_buf, int embed_dim, int intermediate_dim, float * hidden_out)

mlp_token_parallel Forward
void mlp_token_parallel(const float * input, const float * W_fc1, const float * b_fc1, const float * W_fc2, const float * b_fc2, float * fc1_output, float * output, int T, int aligned_dim, int num_threads)

mlp_token_parallel_bf16 Forward
void mlp_token_parallel_bf16(const uint16_t * input, const uint16_t * W_fc1, const uint16_t * b_fc1, const uint16_t * W_fc2, const uint16_t * b_fc2, float * fc1_output, float * output, int T, int aligned_dim, int num_threads, float * scratch_bias1_f, float * scratch_bias2_f, uint16_t * scratch_fc1_bf16)

Optimized MLP Forward (BF16 weights, FP32 activations)

mlp_token_parallel_bf16_fp32act Forward
void mlp_token_parallel_bf16_fp32act(const uint16_t * input, const uint16_t * W_fc1, const uint16_t * b_fc1, const uint16_t * W_fc2, const uint16_t * b_fc2, float * fc1_output, float * output, int T, int aligned_dim, int num_threads, float * scratch_input_f, float * scratch_bias1_f, float * scratch_bias2_f, uint16_t * scratch_fc1_bf16)

Alternative: Fully FP32 activations throughout

mlp_token_parallel_exact Forward
void mlp_token_parallel_exact(const float * input, const float * W_fc1, const float * b_fc1, const float * W_fc2, const float * b_fc2, float * fc1_output, float * output, int T, int aligned_dim, int num_threads)

Sigmoid Activation (5)
sigmoid_backward Backward
void sigmoid_backward(const float * input, const float * d_output, float * d_input, size_t n)

Backward pass / gradient computation

sigmoid_backward_bf16 Backward
void sigmoid_backward_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, size_t n, float * scratch_input, float * scratch_d_output, float * scratch_d_input)

Backward pass / gradient computation

sigmoid_forward Forward
void sigmoid_forward(const float * input, float * output, size_t n)

Forward pass computation

sigmoid_forward_bf16 Forward
void sigmoid_forward_bf16(const uint16_t * input, uint16_t * output, size_t n, float * scratch_input, float * scratch_output)

Forward pass computation

sigmoid_scalar Forward
float sigmoid_scalar(float x)

SwiGLU Activation (7)
ck_test_swiglu Forward
void ck_test_swiglu(const float * gate_up, float * output, int n_tokens, int intermediate_dim)

SwiGLU activation.

swiglu_backward Backward
void swiglu_backward(const float * input, const float * d_output, float * d_input, int tokens, int dim)

SwiGLU backward pass. Tests: test_swiglu.py::TestSwiGLUBackward::test_backward_tokens, test_swiglu.py::TestSwiGLUBackward::test_backward_single, test_parity.py::test_swiglu_backward_parity

swiglu_backward_bf16 Backward
void swiglu_backward_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, int tokens, int dim)

Backward pass / gradient computation

swiglu_backward_exact Backward
void swiglu_backward_exact(const float * input, const float * d_output, float * d_input, int tokens, int dim)

SwiGLU backward pass (exact version using stdlib sigmoid). Tests: test_swiglu.py::TestSwiGLUBackward::test_exact_vs_fast, test_swiglu.py::TestSwiGLUBackward::test_exact_single

swiglu_forward Forward
void swiglu_forward(const float * input, float * output, int tokens, int dim)

SwiGLU forward pass. Tests: test_swiglu.py::TestSwiGLUForward::test_forward_tokens, test_swiglu.py::TestSwiGLUForward::test_forward_single, test_mlp.py::TestMLPForward::test_swiglu_mlp, test_fused_swiglu_decode.py::TestFusedSwiGLUDecode::test_fused_swiglu_decode, test_parity.py::test_swiglu_parity

swiglu_forward_bf16 Forward
void swiglu_forward_bf16(const uint16_t * input, uint16_t * output, int tokens, int dim)

Forward pass computation

swiglu_forward_exact Forward
void swiglu_forward_exact(const float * input, float * output, int tokens, int dim)

SwiGLU forward pass (exact version using stdlib sigmoid). Tests: test_swiglu.py::TestSwiGLUForward::test_exact_vs_fast, test_swiglu.py::TestSwiGLUForward::test_exact_single

RoPE (Rotary Position Embedding) (22)
apply_rope Forward
void apply_rope(float * x, int seq_len, int head_dim)

apply_rope_inline Forward
void apply_rope_inline(float * q, float * k, const float * rope_cos, const float * rope_sin, int pos, int H, int KV, int AD)

ck_model_precompute_rope Forward
void ck_model_precompute_rope(void * model)

Precompute RoPE cos/sin caches. Call once after allocation, before inference.

ck_test_rope Forward
void ck_test_rope(float * q, float * k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)

RoPE (Rotary Position Embedding)

ck_test_rope_interleaved Forward
void ck_test_rope_interleaved(float * q, float * k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)

RoPE with interleaved format (for llama.cpp compatibility)

fused_rope_inplace Forward
void fused_rope_inplace(float * q, float * k, const float * rope_cos, const float * rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int max_seq)

Fused RoPE application (in-place on pre-allocated buffers)

mega_fuse_rope_inplace_avx Forward
void mega_fuse_rope_inplace_avx(float * q, float * k, const float * rope_cos, const float * rope_sin, int pos, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim)

Apply RoPE to Q and K (in-place, from L1)

model_precompute_rope Forward
void model_precompute_rope(MODELModel * model)

qwen2_0_5b_decode_precompute_rope Forward
void qwen2_0_5b_decode_precompute_rope(QWEN2_0_5B_DECODEModel * model)

rope_apply_head Forward
void rope_apply_head(float * x, const float * cos_cache, const float * sin_cache, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)

rope_backward Backward
void rope_backward(const float * d_out, float * d_x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)

RoPE backward (inverse rotation). Tests: test_rope.py::TestRoPEBackward::test_rope_backward, test_rope.py::TestRoPEBackward::test_rope_backward_vs_separate

rope_backward_bf16 Backward
void rope_backward_bf16(const uint16_t * d_out, uint16_t * d_x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float * scratch_d_out, float * scratch_d_x)

RoPE backward (bf16 gradients stored as uint16_t; the caller supplies fp32 scratch buffers).

rope_backward_inplace Backward
void rope_backward_inplace(float * d_x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)

RoPE backward in-place (overwrites d_x with the inverse rotation). Tests: test_rope.py::TestRoPEBackward::test_rope_backward_inplace

rope_backward_qk Backward
void rope_backward_qk(const float * d_q_out, const float * d_k_out, float * d_q, float * d_k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)

RoPE backward for both dQ and dK. Tests: test_rope.py::TestRoPEBackward::test_rope_backward_qk

rope_backward_qk_bf16 Backward
void rope_backward_qk_bf16(const uint16_t * d_q_out, const uint16_t * d_k_out, uint16_t * d_q, uint16_t * d_k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float * scratch_dq_out, float * scratch_dq, float * scratch_dk_out, float * scratch_dk)

RoPE backward for both dQ and dK (bf16 gradients stored as uint16_t; the caller supplies fp32 scratch buffers).

rope_forward Forward
void rope_forward(float * x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)

RoPE forward (head-major layout, in-place). Tests: test_rope.py::TestRoPEForward::test_rope_forward, test_rope.py::TestRoPEForward::test_rope_vs_separate, test_parity.py::test_rope_parity

rope_forward_bf16 Forward
void rope_forward_bf16(uint16_t * x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float * scratch)

RoPE forward, in-place (bf16 data stored as uint16_t; the caller supplies an fp32 scratch buffer).

rope_forward_qk Forward
void rope_forward_qk(float * q, float * k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)

RoPE forward for both Q and K (common inference pattern). Tests: test_rope.py::TestRoPEForward::test_rope_forward_qk, test_fused_attention_decode.py::TestFusedAttentionDecode::test_qk_rope, test_parity.py::test_rope_qk_parity

rope_forward_qk_bf16 Forward
void rope_forward_qk_bf16(uint16_t * q, uint16_t * k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float * scratch_q, float * scratch_k)

RoPE forward for both Q and K, in-place (bf16 data stored as uint16_t; the caller supplies fp32 scratch buffers).

rope_forward_qk_strided Forward
void rope_forward_qk_strided(float * q, float * k, const float * cos_cache, const float * sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)

RoPE forward for both Q and K with custom strides (for KV cache layouts). Tests: test_rope.py::TestRoPEForward::test_rope_forward_qk_strided, test_kv_cache_attention.py::TestKVCacheAttention::test_qk_rope_strided

rope_forward_strided Forward
void rope_forward_strided(float * x, const float * cos_cache, const float * sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)

RoPE forward with custom head stride (for KV cache layouts). Tests: test_rope.py::TestRoPEForward::test_rope_strided, test_kv_cache_attention.py::TestKVCacheAttention::test_rope_decode

rope_precompute_cache
void rope_precompute_cache(float * cos_cache, float * sin_cache, int max_seq_len, int head_dim, float base)

Precompute the RoPE cos/sin cache. Tests: test_rope.py::TestRoPECache::test_cache_computation, test_rope.py::TestRoPECache::test_cache_values

Fully Connected Layers (2)
fc1_backward_kernel Backward
void fc1_backward_kernel(const float * d_output, const float * fc1_input, const float * W_fc1, float * d_input, float * d_W_fc1, float * d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)

Backward pass for the FC1 layer: computes d_input, d_W_fc1, and d_b_fc1 from d_output.

fc2_backward_kernel Backward
void fc2_backward_kernel(const float * d_output, const float * fc2_input, const float * W_fc2, float * d_input, float * d_W_fc2, float * d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)

Backward pass for the FC2 layer: computes d_input, d_W_fc2, and d_b_fc2 from d_output.

Other Functions (877)
_Static_assert Forward
 _Static_assert(sizeof(MagicHeader) == 64, "MagicHeader must be 64 bytes")

__attribute__ Forward
struct __attribute__((packed))

adamw_update_bf16 Forward
void adamw_update_bf16(const uint16_t * grad, uint16_t * weight, float * m, float * v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)

AdamW optimizer update (bf16 weights/gradients, fp32 optimizer state)

adamw_update_f32 Forward
void adamw_update_f32(const float * grad, float * weight, float * m, float * v, size_t numel, float lr, float beta1, float beta2, float eps, float weight_decay, int step)

AdamW optimizer update (fp32 version)

add_backward_bf16 Backward
void add_backward_bf16(const uint16_t * d_y, uint16_t * d_a, uint16_t * d_b, size_t n)

Backward pass for element-wise add (bf16): propagates d_y to both d_a and d_b.

add_bias_tile Forward
void add_bias_tile(float * out, const float * bias, int tile_m, int out_dim)

add_forward_2d_bf16 Forward
void add_forward_2d_bf16(const uint16_t * a, const uint16_t * b, uint16_t * y, int tokens, int dim, int aligned_dim)

Element-wise add y = a + b for 2D bf16 tensors (tokens × dim; aligned_dim is the padded row width).

add_forward_bf16 Forward
void add_forward_bf16(const uint16_t * a, const uint16_t * b, uint16_t * y, size_t n)

Element-wise add (bf16): y = a + b.

add_forward_f32 Forward
void add_forward_f32(const float * a, const float * b, float * y, size_t n)

Element-wise add: y = a + b. Tests: test_add.py::TestAddForward::test_add_forward_f32, test_add.py::TestAddForward::test_add_inplace_f32, test_multi_layer_parity.py::TestMultiLayerParity::test_residual_add

add_inplace_bf16 Forward
void add_inplace_bf16(uint16_t * a, const uint16_t * b, size_t n)

add_inplace_f32 Forward
void add_inplace_f32(float * a, const float * b, size_t n)

add_scaled_forward_bf16 Forward
void add_scaled_forward_bf16(const uint16_t * a, const uint16_t * b, uint16_t * y, float alpha, size_t n)

Scaled element-wise add (bf16) with scale factor alpha, writing the result to y.

add_scaled_inplace_bf16 Forward
void add_scaled_inplace_bf16(uint16_t * a, const uint16_t * b, float alpha, size_t n)

align_up Forward
size_t align_up(size_t n, size_t align)

align_up_bytes Forward
size_t align_up_bytes(size_t n, size_t align)

align_up_elems Forward
size_t align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)

align_up_size Forward
size_t align_up_size(size_t value, size_t align)

amx_available Forward
bool amx_available(void)

apply_bpe_merges Forward
int apply_bpe_merges(CKTrueBPE * bpe, CKBPETokenList * list)

apply_chat_template Forward
char * apply_chat_template(const ChatTemplate * tmpl, const char * system, const char * user)

arena_for_role Forward
CKMemArenaKind arena_for_role(CKBufferRole role)

argmax_f32 Forward
int argmax_f32(const float * scores, int n)

Find index of maximum value.

axpy_2d_f32 Forward
void axpy_2d_f32(float * Y, const float * X, float alpha, int num_tokens, int dim, int y_stride, int x_stride)

Batched AXPY for 2D tensors: Y[t,:] += alpha * X[t,:].

axpy_f32 Forward
void axpy_f32(float * y, const float * x, float alpha, int n)

In-place AXPY: y += alpha * x.

axpy_zero_f32 Forward
void axpy_zero_f32(float * y, const float * x, float alpha, int n)

Zero output then accumulate: y = 0; y += alpha * x.

barrier_init Forward
void barrier_init(ck_barrier_t * b, int n_threads)

barrier_wait Forward
void barrier_wait(ck_barrier_t * b)

Spin-wait barrier. All threads must call this. It uses a phase counter so the barrier can be reused without a reset.

bf16_tensor_to_float Forward
void bf16_tensor_to_float(const uint16_t * src, float * dst, size_t count)

bf16_to_float Forward
float bf16_to_float(uint16_t v)

buffer_bytes Forward
size_t buffer_bytes(const CKIRV2Buffer * buf, const CKModelConfig * cfg, const CKV2AlignInfo * align)

buffer_enabled Forward
int buffer_enabled(const CKIRV2Graph * graph, const CKIRV2Buffer * buf, int training_enabled)

build_plan Forward
int build_plan(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes, int training_enabled, int tokens_override)

bump_bytes Forward
size_t bump_bytes(size_t * off, size_t bytes, size_t align)

byte_to_gpt2 Forward
int byte_to_gpt2(unsigned char byte, char * out)

cache_compare
int cache_compare(const void * a, const void * b)

ck_add_inplace Forward
void ck_add_inplace(float * dst, const float * src, int tokens, int aligned_embed_dim)

ck_buffer_should_alloc Forward
int ck_buffer_should_alloc(const CKBufferSpec * spec)

ck_buffer_uses_weight_dtype Forward
int ck_buffer_uses_weight_dtype(const CKBufferSpec * spec)

ck_build_decoder_backward_ir Backward
int ck_build_decoder_backward_ir(const CKIRGraph * forward, CKIRGraph * backward)

Build a naive backward IR graph from a forward decoder IR.

ck_build_decoder_ir Forward
int ck_build_decoder_ir(const CKModelConfig * cfg, CKIRGraph * graph)

Build a simple decoder-only IR graph for the given config.

ck_codegen_c_skeleton Forward
void ck_codegen_c_skeleton(const CKIRGraph * forward, const CKIRGraph * backward, FILE * out)

Emit a C skeleton for forward + backward execution based on the IR.

ck_codegen_emit_runtime Forward
int ck_codegen_emit_runtime(const CKIRGraph * forward, const char * path, CKEmitMode mode)

Emit a C runtime file that stitches kernels for the given forward IR.

ck_codegen_v2_dtype_name Forward
const char * ck_codegen_v2_dtype_name(CKDataType dtype)

ck_codegen_v2_emit_dispatch Forward
void ck_codegen_v2_emit_dispatch(FILE * out, const CKIRV2Graph * graph)

ck_codegen_v2_emit_preamble Forward
int ck_codegen_v2_emit_preamble(FILE * out)

ck_codegen_v2_emit_runtime Forward
int ck_codegen_v2_emit_runtime(const CKIRV2Graph * graph, const char * path, CKEmitMode mode)

Emit a C runtime file from a CKIRV2Graph.

ck_codegen_v2_emit_schedule Forward
void ck_codegen_v2_emit_schedule(FILE * out, const CKIRV2Graph * graph, const char * prefill_runtime, const char * decode_runtime, const char * backward_runtime)

ck_codegen_v2_emit_sections Forward
void ck_codegen_v2_emit_sections(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * prefill_plan, const CKMemPlan * decode_plan, const CKMemPlan * backward_plan)

ck_codegen_v2_emit_struct Forward
void ck_codegen_v2_emit_struct(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan, const char * tag)

ck_debug_check_buffer Forward
void ck_debug_check_buffer(const char * stage, const float * buf, int size)

ck_debug_check_q4k_weights Forward
void ck_debug_check_q4k_weights(const char * stage, const void * q4_buf, int num_blocks)

ck_debug_check_q8k Forward
void ck_debug_check_q8k(const char * stage, const void * q8_buf, int num_blocks)

ck_dot_f32 Forward
float ck_dot_f32(const float * a, const float * b, int len)

ck_dtype_block_bytes Forward
size_t ck_dtype_block_bytes(CKDataType dt)

Get bytes per block for quantized types.

ck_dtype_block_size Forward
size_t ck_dtype_block_size(CKDataType dt)

Get the number of elements per quantization block.

ck_dtype_bytes Forward
size_t ck_dtype_bytes(CKDataType dt)

Get bytes per element for non-quantized types.

ck_dtype_is_quantized Forward
int ck_dtype_is_quantized(CKDataType dt)

Check if a data type is block-quantized (GGML-style)

ck_dtype_row_bytes Forward
size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)

Calculate total bytes for n_elements of given dtype.

ck_dtype_supported Forward
int ck_dtype_supported(CKDataTypeMask mask, CKDataType dt)

ck_expf Forward
float ck_expf(float x)

ck_fast_expf Forward
float ck_fast_expf(float x)

ck_find_buffer_spec Forward
const CKBufferSpec * ck_find_buffer_spec(const char * name)

ck_find_kernel_spec Forward
const CKKernelSpec * ck_find_kernel_spec(const char * name)

ck_first_layer_buffer_name Forward
const char * ck_first_layer_buffer_name(void)

ck_flash_attn_choose_tile_k Forward
int ck_flash_attn_choose_tile_k(int D_h)

ck_flash_attn_fast_exp_kind Forward
int ck_flash_attn_fast_exp_kind(void)

ck_flash_attn_tile_k Forward
int ck_flash_attn_tile_k(int D_h)

ck_fma_f32_to_f16 Forward
void ck_fma_f32_to_f16(const float * a, const float * b, const float * c, uint16_t * dst, int n)

FMA in FP32, store result as FP16: dst = a * b + c.

ck_fp16_to_fp32 Forward
float ck_fp16_to_fp32(ck_half h)

ck_fp16_to_fp32_2d Forward
void ck_fp16_to_fp32_2d(const uint16_t * src, float * dst, int rows, int cols, int src_stride, int dst_stride)

Convert 2D FP16 matrix to FP32 with strided access.

ck_fp16_to_fp32_row Forward
void ck_fp16_to_fp32_row(const uint16_t * src, float * dst, int n)

Convert FP16 row to FP32 (auto-select best implementation)

ck_fp16_to_fp32_scalar Forward
float ck_fp16_to_fp32_scalar(uint16_t h)

ck_fp16_to_fp32_soft Forward
float ck_fp16_to_fp32_soft(ck_half h)

Convert FP16 (ck_half) to FP32 — software implementation.

ck_fp32_to_fp16 Forward
ck_half ck_fp32_to_fp16(float f)

ck_fp32_to_fp16_2d Forward
void ck_fp32_to_fp16_2d(const float * src, uint16_t * dst, int rows, int cols, int src_stride, int dst_stride)

Convert 2D FP32 matrix to FP16 with strided access.

ck_fp32_to_fp16_inplace Forward
void ck_fp32_to_fp16_inplace(float * data, void * scratch, int n)

Convert FP32 to FP16 in-place using scratch buffer.

ck_fp32_to_fp16_row Forward
void ck_fp32_to_fp16_row(const float * src, uint16_t * dst, int n)

Convert FP32 row to FP16 (auto-select best implementation)

ck_fp32_to_fp16_scalar Forward
uint16_t ck_fp32_to_fp16_scalar(float f)

ck_fp32_to_fp16_soft Forward
ck_half ck_fp32_to_fp16_soft(float f)

Convert FP32 to FP16 (ck_half) — software implementation.

ck_get_block_q4_k_size Forward
int ck_get_block_q4_k_size(void)

Get Q4_K block size in bytes.

ck_get_block_q5_1_size Forward
int ck_get_block_q5_1_size(void)

Get Q5_1 block size in bytes (24 bytes per 32 weights)

ck_get_block_q5_k_size Forward
int ck_get_block_q5_k_size(void)

Get Q5_K block size in bytes (176 bytes per 256 weights)

ck_get_block_q6_k_size Forward
int ck_get_block_q6_k_size(void)

Get Q6_K block size in bytes.

ck_get_block_q8_k_size Forward
int ck_get_block_q8_k_size(void)

Get Q8_K block size in bytes.

ck_get_capabilities Forward
ck_capability_t ck_get_capabilities(void)

Get current platform capabilities.

ck_get_num_threads Forward
int ck_get_num_threads(void)

ck_get_physical_cores Forward
int ck_get_physical_cores(void)

ck_get_qk5_1 Forward
int ck_get_qk5_1(void)

Get QK5_1 (elements per Q5_1 block)

ck_get_qk_k Forward
int ck_get_qk_k(void)

Get QK_K (elements per super-block)

ck_get_threadpool Forward
ck_threadpool_t * ck_get_threadpool(void)

Get the global thread pool handle for dispatch. Convenience wrapper — initializes on first call.

ck_huge_alloc Forward
void * ck_huge_alloc(size_t bytes)

Allocate a large, contiguous memory region for model weights/activations.

ck_huge_free Forward
void ck_huge_free(void * ptr, size_t bytes)

Free memory allocated by ck_huge_alloc.

ck_ir_dump Forward
void ck_ir_dump(const CKIRGraph * graph, FILE * out)

Dump a human-readable view of the IR to the given stream.

ck_ir_free Forward
void ck_ir_free(CKIRGraph * graph)

Free any heap-allocated memory owned by the graph.

ck_ir_parse_json Forward
int ck_ir_parse_json(const char * path, CKIRGraph * graph)

Parse a JSON IR map file (as produced by ck_ir_serialize_json) back into a CKIRGraph. This enables a two-stage pipeline: first serialize the IR to JSON with ck_ir_serialize_json, then parse it back with this function.

ck_ir_serialize_json Forward
int ck_ir_serialize_json(const CKIRGraph * graph, const char * path)

Serialize a CKIRGraph to a simple JSON IR map file.

ck_ir_v2_align_up_bytes Forward
size_t ck_ir_v2_align_up_bytes(size_t n, size_t align)

ck_ir_v2_align_up_elems Forward
size_t ck_ir_v2_align_up_elems(size_t elems, size_t elem_bytes, size_t align_bytes)

ck_ir_v2_apply_meta Forward
int ck_ir_v2_apply_meta(const char * path, CKIRV2Graph * graph)

ck_ir_v2_apply_weight_dtypes Forward
int ck_ir_v2_apply_weight_dtypes(const char * json, const char * end, CKIRV2Graph * graph)

ck_ir_v2_build_decoder Forward
int ck_ir_v2_build_decoder(const CKModelConfig * cfg, CKIRV2Graph * graph)

ck_ir_v2_build_decoder_backward Backward
int ck_ir_v2_build_decoder_backward(const CKIRV2Graph * forward, CKIRV2Graph * backward)

Build a backward IR (v2) graph from a forward decoder IR graph.

ck_ir_v2_copy_buffer_spec Forward
int ck_ir_v2_copy_buffer_spec(const CKBufferSpec * spec, CKIRV2Buffer * out)

ck_ir_v2_copy_shape Forward
void ck_ir_v2_copy_shape(CKDimToken * dst, const CKDimToken * src)

ck_ir_v2_dim_kind_from_name Forward
CKDimKind ck_ir_v2_dim_kind_from_name(const char * name)

ck_ir_v2_dim_name Forward
const char * ck_ir_v2_dim_name(CKDimKind dim)

ck_ir_v2_dtype_name Forward
const char * ck_ir_v2_dtype_name(CKDataType dtype)

ck_ir_v2_emit_dimensions Forward
void ck_ir_v2_emit_dimensions(FILE * out, const CKModelConfig * cfg, const CKIRV2AlignInfo * align, int tokens_override)

ck_ir_v2_emit_memory_plan Forward
void ck_ir_v2_emit_memory_plan(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan)

ck_ir_v2_emit_resolved_shape Forward
void ck_ir_v2_emit_resolved_shape(FILE * out, const CKModelConfig * cfg, const CKIRV2AlignInfo * align, const CKDimToken * shape, int tokens_override)

ck_ir_v2_emit_shape Forward
int ck_ir_v2_emit_shape(FILE * out, const CKDimToken * shape)

ck_ir_v2_find_array_end Forward
const char * ck_ir_v2_find_array_end(const char * open, const char * end)

ck_ir_v2_find_buffer_index Forward
int ck_ir_v2_find_buffer_index(const CKIRV2Graph * graph, const char * name)

ck_ir_v2_find_buffer_spec Forward
const CKBufferSpec * ck_ir_v2_find_buffer_spec(const char * name)

ck_ir_v2_find_kernel_spec Forward
const CKKernelSpec * ck_ir_v2_find_kernel_spec(const char * name)

ck_ir_v2_find_key Forward
const char * ck_ir_v2_find_key(const char * json, const char * key, const char * end)

ck_ir_v2_free Forward
void ck_ir_v2_free(CKIRV2Graph * graph)

ck_ir_v2_free_buffer Forward
void ck_ir_v2_free_buffer(CKIRV2Buffer * buf)

ck_ir_v2_free_node Forward
void ck_ir_v2_free_node(CKIRV2Node * node)

ck_ir_v2_lower_copy_buffers Forward
int ck_ir_v2_lower_copy_buffers(const CKIRV2Graph * input, CKIRV2Graph * output)

ck_ir_v2_lower_copy_nodes Forward
int ck_ir_v2_lower_copy_nodes(const CKIRV2Graph * input, CKIRV2LowerMode mode, CKIRV2Graph * output)

ck_ir_v2_lower_emit_json Forward
int ck_ir_v2_lower_emit_json(const CKIRV2Graph * input, CKIRV2LowerMode mode, const char * path)

ck_ir_v2_lower_graph Forward
int ck_ir_v2_lower_graph(const CKIRV2Graph * input, CKIRV2LowerMode mode, CKIRV2Graph * output, CKMemPlan * plan)

ck_ir_v2_lower_mode_from_string Forward
int ck_ir_v2_lower_mode_from_string(const char * name, CKIRV2LowerMode * out_mode)

ck_ir_v2_lower_mode_name Forward
const char * ck_ir_v2_lower_mode_name(CKIRV2LowerMode mode)

ck_ir_v2_lower_node_enabled Forward
int ck_ir_v2_lower_node_enabled(const CKIRV2Node * node, CKIRV2LowerMode mode)

ck_ir_v2_lower_strdup Forward
char * ck_ir_v2_lower_strdup(const char * s)

ck_ir_v2_mem_arena_name Forward
const char * ck_ir_v2_mem_arena_name(CKMemArenaKind arena)

ck_ir_v2_next_object Forward
const char * ck_ir_v2_next_object(const char * cur, const char * end, const char ** obj_start, const char ** obj_end)

ck_ir_v2_parse_bindings Forward
int ck_ir_v2_parse_bindings(const char * obj_start, const char * obj_end, CKIRV2Graph * graph, CKIRV2Node * node)

ck_ir_v2_parse_bool Forward
int ck_ir_v2_parse_bool(const char * json, const char * key, const char * end, int * out_val)

ck_ir_v2_parse_buffers Forward
int ck_ir_v2_parse_buffers(const char * json, const char * end, CKIRV2Graph * graph)

ck_ir_v2_parse_dim_kind Forward
CKDimKind ck_ir_v2_parse_dim_kind(const char * obj_start, const char * obj_end)

ck_ir_v2_parse_dtype Forward
CKDataType ck_ir_v2_parse_dtype(const char * s)

ck_ir_v2_parse_float Forward
int ck_ir_v2_parse_float(const char * json, const char * key, const char * end, float * out_val)

ck_ir_v2_parse_int Forward
int ck_ir_v2_parse_int(const char * json, const char * key, const char * end, int * out_val)

ck_ir_v2_parse_json Forward
int ck_ir_v2_parse_json(const char * path, CKIRV2Graph * graph)

ck_ir_v2_parse_nodes Forward
int ck_ir_v2_parse_nodes(const char * json, const char * end, CKIRV2Graph * graph)

ck_ir_v2_parse_role Forward
CKBufferRole ck_ir_v2_parse_role(const char * s)

ck_ir_v2_parse_scope Forward
CKBufferScope ck_ir_v2_parse_scope(const char * s)

ck_ir_v2_parse_shape Forward
int ck_ir_v2_parse_shape(const char * obj_start, const char * obj_end, CKDimToken * shape_out)

ck_ir_v2_parse_string Forward
int ck_ir_v2_parse_string(const char * start, const char * end, char ** out_str)

ck_ir_v2_parse_string_field Forward
int ck_ir_v2_parse_string_field(const char * json, const char * key, const char * end, char ** out_str)

ck_ir_v2_resolve_align Forward
void ck_ir_v2_resolve_align(const CKModelConfig * cfg, size_t alignment_bytes, CKIRV2AlignInfo * align)

ck_ir_v2_resolve_dim_value Forward
size_t ck_ir_v2_resolve_dim_value(const CKModelConfig * cfg, const CKIRV2AlignInfo * align, CKDimKind dim, int tokens_override)

ck_ir_v2_role_name Forward
const char * ck_ir_v2_role_name(CKBufferRole role)

ck_ir_v2_scope_name Forward
const char * ck_ir_v2_scope_name(CKBufferScope scope)

ck_ir_v2_select_kernel Forward
const char * ck_ir_v2_select_kernel(const CKKernelSpec * spec, CKDataType dtype, int backward)

ck_ir_v2_serialize_json Forward
int ck_ir_v2_serialize_json(const CKIRV2Graph * graph, const char * path)

ck_ir_v2_serialize_json_internal Forward
int ck_ir_v2_serialize_json_internal(const CKIRV2Graph * graph, const CKMemPlan * plan, const char * mode, int tokens_override, int base_context_window, const char * path)

ck_ir_v2_serialize_json_with_plan Forward
int ck_ir_v2_serialize_json_with_plan(const CKIRV2Graph * graph, const CKMemPlan * plan, const char * mode, int tokens_override, int base_context_window, const char * path)

ck_ir_v2_skip_string Forward
const char * ck_ir_v2_skip_string(const char * cur, const char * end)

ck_ir_v2_skip_ws Forward
const char * ck_ir_v2_skip_ws(const char * cur, const char * end)

ck_ir_v2_strdup Forward
char * ck_ir_v2_strdup(const char * s)

ck_ir_validate_supported Forward
int ck_ir_validate_supported(const CKIRGraph * graph)

ck_layer_debug_enabled Forward
int ck_layer_debug_enabled(void)

ck_load_weights_manifest_v4 Forward
int ck_load_weights_manifest_v4(void * base, const char * weights_path, const char * manifest_path)

Load BUMPWGT4 weights into a v4 model buffer using a manifest map.

ck_mem_plan_build_inference Forward
int ck_mem_plan_build_inference(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes)

ck_mem_plan_build_inference_with_tokens Forward
int ck_mem_plan_build_inference_with_tokens(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes, int tokens_override)

ck_mem_plan_build_training Forward
int ck_mem_plan_build_training(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes)

ck_mem_plan_build_training_with_tokens Forward
int ck_mem_plan_build_training_with_tokens(const CKIRV2Graph * graph, CKMemPlan * plan, size_t alignment_bytes, int tokens_override)

ck_mem_plan_free Forward
void ck_mem_plan_free(CKMemPlan * plan)

ck_memory_allocate Forward
int ck_memory_allocate(CKModel * model, int use_hugepages)

Allocate the planned memory.

ck_memory_free Forward
void ck_memory_free(CKModel * model)

Free the model memory.

ck_memory_plan Forward
size_t ck_memory_plan(const CKSectionConfig * sections, int num_sections, int mode, uint32_t fusion_flags, CKModel * out_model)

Plan memory layout for a model.

ck_metrics_cleanup Forward
void ck_metrics_cleanup(void)

Clean up and free resources.

ck_metrics_create_context Forward
CKMetricsContext * ck_metrics_create_context(void)

ck_metrics_ctx_end Forward
void ck_metrics_ctx_end(CKMetricsContext * ctx, const char * status)

ck_metrics_ctx_init Forward
bool ck_metrics_ctx_init(CKMetricsContext * ctx, const char * run_id, const char * endpoint, CKMetricsMode mode)

ck_metrics_ctx_log_f Forward
void ck_metrics_ctx_log_f(CKMetricsContext * ctx, const char * name, double value)

ck_metrics_ctx_log_i Forward
void ck_metrics_ctx_log_i(CKMetricsContext * ctx, const char * name, int64_t value)

ck_metrics_ctx_step Forward
void ck_metrics_ctx_step(CKMetricsContext * ctx, int64_t step)

ck_metrics_destroy_context Forward
void ck_metrics_destroy_context(CKMetricsContext * ctx)

ck_metrics_end Forward
void ck_metrics_end(const char * status)

End the training run

ck_metrics_generate_run_id Forward
void ck_metrics_generate_run_id(char * buffer, size_t size)

Generate a unique run ID based on current timestamp

ck_metrics_get_memory_mb Forward
int64_t ck_metrics_get_memory_mb(void)

Get current memory usage in MB (platform-specific)

ck_metrics_init Forward
bool ck_metrics_init(const char * run_id, const char * endpoint, CKMetricsMode mode)

Initialize metrics logging

ck_metrics_init_full Forward
bool ck_metrics_init_full(const char * run_id, const char * endpoint, CKMetricsMode mode, const char * model, const char * dataset, int batch_size, double lr, int max_steps)

Initialize with full configuration

ck_metrics_log_f Forward
void ck_metrics_log_f(const char * name, double value)

Log a float metric (e.g., loss, learning rate)

ck_metrics_log_i Forward
void ck_metrics_log_i(const char * name, int64_t value)

Log an integer metric (e.g., step, tokens_per_sec)

ck_metrics_log_s Forward
void ck_metrics_log_s(const char * name, const char * value)

Log a string metric (e.g., phase, status)

ck_metrics_step Forward
void ck_metrics_step(int64_t step)

Flush metrics for the current step and advance to the next step. Call this at the end of each training step.

ck_metrics_timestamp Forward
double ck_metrics_timestamp(void)

Get current timestamp in seconds with microsecond precision

ck_min Forward
int ck_min(int a, int b)

ck_min_i Forward
int ck_min_i(int a, int b)

ck_model_allocate Forward
int ck_model_allocate(CKModel * model, int hugepage_mode)

Allocate the planned memory. hugepage_mode: 0 = normal pages, 1 = 2 MB hugepages, 2 = 1 GB hugepages. Returns 0 on success, -1 on failure.

ck_model_config_from_hf_json Forward
int ck_model_config_from_hf_json(const char * path, CKModelConfig * cfg)

Parse a HuggingFace-style config.json into CKModelConfig.

ck_model_create Forward
void * ck_model_create(void)

Create and allocate model memory. Returns opaque model pointer, or NULL on failure.

ck_model_decode Forward
void ck_model_decode(void * model, const int * token, int token_index)

Decode single token at position token_index. Used for autoregressive generation.

ck_model_forward Forward
void ck_model_forward(void * model, const int * tokens, int num_tokens)

Forward pass (prefill) - process multiple tokens. Used for initial prompt processing.

ck_model_free Forward
void ck_model_free(void * model)

Free model memory.

ck_model_get_base Forward
void * ck_model_get_base(void * model)

Get model base pointer (for weight loading).

ck_model_get_config Forward
const CKModelConfig * ck_model_get_config(void)

Get the model configuration (dimensions, sizes, etc.). This is available before allocation.

ck_model_get_logits Forward
float * ck_model_get_logits(void * model)

Get pointer to output logits buffer. Size is vocab_size floats.

ck_model_get_total_bytes Forward
size_t ck_model_get_total_bytes(void * model)

Get total model size in bytes.

ck_model_load_weights Forward
int ck_model_load_weights(void * model, const char * bump_path)

Load weights from BUMP file into model. Returns 0 on success, -1 on failure.

ck_model_load_weights_flat Forward
int ck_model_load_weights_flat(TransformerModel * m, const char * path)

Load weights from a single flat binary file into model->memory_base.

ck_model_plan Forward
size_t ck_model_plan(CKModel * model, const CKSectionConfig * configs, int num_sections, int training_enabled, uint32_t fusion_flags)

Plan memory layout for complete model. Returns total bytes needed.

ck_model_verify_canaries Forward
int ck_model_verify_canaries(void * model)

Verify memory canaries (debug). Returns number of corrupted canaries (0 = OK).

ck_murmurhash3 Forward
uint32_t ck_murmurhash3(const char * key, uint32_t len, uint32_t seed)

32-bit MurmurHash3 hash function (original HPC_Embeddings version).

ck_murmurhash3_128 Forward
void ck_murmurhash3_128(const void * key, size_t len, uint32_t seed, uint64_t * out1, uint64_t * out2)

128-bit MurmurHash3 hash function (produces two 64-bit values, written to out1 and out2).

ck_murmurhash3_32 Forward
uint32_t ck_murmurhash3_32(const void * key, size_t len, uint32_t seed)

32-bit MurmurHash3 hash function (alternative name).

ck_murmurhash3_str Forward
uint32_t ck_murmurhash3_str(const char * key, uint32_t seed)

ck_murmurhash3_strn Forward
uint32_t ck_murmurhash3_strn(const char * str, size_t len, uint32_t seed)

32-bit MurmurHash3 for a string with an explicit length.

ck_nearest_int Forward
int ck_nearest_int(float fval)

ck_nearest_int_fused Forward
int ck_nearest_int_fused(float fval)

ck_op_name Forward
const char * ck_op_name(CKOpType op)

ck_op_supported Forward
int ck_op_supported(CKOpType op)

ck_parse_env_int Forward
int ck_parse_env_int(const char * name)

ck_plan_step_enabled Forward
int ck_plan_step_enabled(const CKPlanStep * step, const CKIRGraph * cfg)

ck_pool_alloc Forward
void * ck_pool_alloc(CKMemPool * pool, size_t size)

ck_pool_free Forward
void ck_pool_free(CKMemPool * pool)

ck_pool_init Forward
void ck_pool_init(CKMemPool * pool)

ck_pool_strdup Forward
char * ck_pool_strdup(CKMemPool * pool, const char * s, int len)

ck_prefill_forward Forward
int ck_prefill_forward(const void * weights, const int32_t * tokens, int n_tokens, float * hidden_out, void * kv_cache, int kv_pos)

Prefill forward pass over n_tokens input tokens; hidden states are written to hidden_out.

ck_q8_0_outproj_enabled Forward
int ck_q8_0_outproj_enabled(void)

ck_q8k_activations_enabled Forward
int ck_q8k_activations_enabled(void)

ck_qkv_project_head_major Forward
void ck_qkv_project_head_major(const float * input, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

ck_qkv_project_head_major_backward Backward
void ck_qkv_project_head_major_backward(const float * d_q, const float * d_k, const float * d_v, const float * input, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, float * d_input, float * d_wq, float * d_bq, float * d_wk, float * d_bk, float * d_wv, float * d_bv, float * scratch, int tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim, int num_threads)

Backward pass / gradient computation

ck_qkv_project_head_major_q4_k Forward
void ck_qkv_project_head_major_q4_k(const float * input, const void * wq, const float * bq, const void * wk, const float * bk, const void * wv, const float * bv, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

ck_qkv_project_head_major_q4_k_q8_k Forward
void ck_qkv_project_head_major_q4_k_q8_k(const float * input, const void * wq, const float * bq, const void * wk, const float * bk, const void * wv, const float * bv, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

ck_qkv_project_head_major_quant Forward
void ck_qkv_project_head_major_quant(const float * input, const void * wq, const float * bq, CKDataType wq_dtype, const void * wk, const float * bk, CKDataType wk_dtype, const void * wv, const float * bv, CKDataType wv_dtype, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

ck_qkv_project_head_major_ref Forward
void ck_qkv_project_head_major_ref(const float * input, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, float * q, float * k, float * v, int tokens, int kv_stride_tokens, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

ck_qkv_project_head_major_token Forward
void ck_qkv_project_head_major_token(const float * input_row, const float * wq, const float * bq, const float * wk, const float * bk, const float * wv, const float * bv, float * q_token, float * k_token, float * v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

ck_qkv_project_head_major_token_q4_k Forward
void ck_qkv_project_head_major_token_q4_k(const float * input_row, const void * wq, const float * bq, const void * wk, const float * bk, const void * wv, const float * bv, float * q_token, float * k_token, float * v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

ck_qkv_project_head_major_token_q4_k_q8_k Forward
void ck_qkv_project_head_major_token_q4_k_q8_k(const block_q8_K * input_q8, const void * wq, const float * bq, const void * wk, const float * bk, const void * wv, const float * bv, float * q_token, float * k_token, float * v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

ck_qkv_project_head_major_token_quant Forward
void ck_qkv_project_head_major_token_quant(const float * input_row, const void * wq, const float * bq, CKDataType wq_dtype, const void * wk, const float * bk, CKDataType wk_dtype, const void * wv, const float * bv, CKDataType wv_dtype, float * q_token, float * k_token, float * v_token, int aligned_embed_dim, int num_heads, int num_kv_heads, int aligned_head_dim)

ck_quant_block_size Forward
size_t ck_quant_block_size(int type)

Get the block size (number of weights per block) for a quant type.

ck_quant_row_size Forward
size_t ck_quant_row_size(int type, int64_t n_elements)

Calculate the total bytes needed to store n_elements with the given quant type.

ck_quant_type_size Forward
size_t ck_quant_type_size(int type)

Get the byte size per block for a quant type.

ck_residual_add_backward Backward
void ck_residual_add_backward(const float * d_out, float * d_a, float * d_b, int tokens, int aligned_embed_dim)

Backward pass / gradient computation

ck_residual_add_token_major Forward
void ck_residual_add_token_major(const float * a, const float * b, float * out, int tokens, int aligned_embed_dim)

ck_round_nearest Forward
int ck_round_nearest(float v)

Round to the nearest int, with halves rounded away from zero (matches quantize_row_q8_0).

ck_scale_f32_to_f16 Forward
void ck_scale_f32_to_f16(const float * src, float scale, uint16_t * dst, int n)

Scale FP32 array and store as FP16: dst = scale * src.

ck_section_config_init Forward
void ck_section_config_init(CKSectionConfig * config, size_t simd_align)

Initialize section config with computed alignments.

ck_section_plan Forward
size_t ck_section_plan(CKSection * section, const CKSectionConfig * config, int training_enabled, size_t base_offset)

Plan memory layout for a single section. Returns bytes needed for this section.

ck_set_num_threads Forward
void ck_set_num_threads(int num_threads)

ck_set_strict_parity Forward
void ck_set_strict_parity(int enabled)

ck_strict_parity_enabled Forward
int ck_strict_parity_enabled(void)

ck_test_dequant_q4_0 Forward
void ck_test_dequant_q4_0(const void * src, float * dst, int n)

Dequantize Q4_0 data to FP32.

ck_test_dequant_q4_k Forward
void ck_test_dequant_q4_k(const void * src, float * dst, int n)

Dequantize Q4_K data to FP32.

ck_test_dequant_q5_1 Forward
void ck_test_dequant_q5_1(const void * src, float * dst, int n)

Dequantize Q5_1 data to FP32.

ck_test_dequant_q6_k Forward
void ck_test_dequant_q6_k(const void * src, float * dst, int n)

Dequantize Q6_K data to FP32.

ck_test_geglu Forward
void ck_test_geglu(const float * x, float * out, int n_tokens, int dim)

Test GeGLU activation.

ck_test_geglu_backward Backward
void ck_test_geglu_backward(const float * x, const float * d_out, float * d_x, int n_tokens, int dim)

Test GeGLU backward.

ck_test_gemv_q4_k Forward
void ck_test_gemv_q4_k(const void * weight_q4k, const float * input_f32, float * output, int cols)

Q4_K GEMV - dot product of quantized weights and FP32 input.

ck_test_gemv_q5_0 Forward
void ck_test_gemv_q5_0(const void * weight_q5_0, const float * input_f32, float * output, int rows, int cols)

Q5_0 GEMV - matrix-vector multiply with Q5_0 weights.

ck_test_gemv_q5_0_q8_0 Forward
void ck_test_gemv_q5_0_q8_0(const void * weight_q5_0, const float * input_f32, float * output, int rows, int cols)

Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.

ck_test_gemv_q5_1 Forward
void ck_test_gemv_q5_1(const void * weight_q5_1, const float * input_f32, float * output, int rows, int cols)

Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks)

ck_test_gemv_q5_k Forward
void ck_test_gemv_q5_k(const void * weight_q5_k, const float * input_f32, float * output, int rows, int cols)

Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks)

ck_test_gemv_q6_k Forward
void ck_test_gemv_q6_k(const void * weight_q6k, const float * input_f32, float * output, int cols)

Q6_K GEMV.

ck_test_gemv_q8_0 Forward
void ck_test_gemv_q8_0(const void * weight_q8_0, const float * input_f32, float * output, int rows, int cols)

Q8_0 GEMV - matrix-vector multiply with Q8_0 weights.

ck_test_gemv_q8_0_q8_0 Forward
void ck_test_gemv_q8_0_q8_0(const void * weight_q8_0, const float * input_f32, float * output, int rows, int cols)

Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.

ck_test_quantize_q8_k Forward
void ck_test_quantize_q8_k(const float * src, void * dst, int n)

Quantize FP32 to Q8_K (for activations)

ck_test_vec_dot_q5_0_q8_0 Forward
void ck_test_vec_dot_q5_0_q8_0(const void * weight_q5_0, const void * input_q8_0, float * output, int cols)

Direct Q5_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)

ck_test_vec_dot_q8_0_q8_0 Forward
void ck_test_vec_dot_q8_0_q8_0(const void * weight_q8_0, const void * input_q8_0, float * output, int cols)

Direct Q8_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)

ck_threadpool_barrier Forward
void ck_threadpool_barrier(ck_threadpool_t * pool)

Barrier synchronization within a dispatched work function.

ck_threadpool_create Forward
ck_threadpool_t * ck_threadpool_create(int n_threads)

Create a thread pool with n_threads total threads. Thread 0 is the calling (main) thread; n_threads-1 workers are spawned.

ck_threadpool_destroy Forward
void ck_threadpool_destroy(ck_threadpool_t * pool)

Destroy the thread pool. Signals all workers to exit and joins them. Safe to call with NULL.

ck_threadpool_dispatch Forward
void ck_threadpool_dispatch(ck_threadpool_t * pool, ck_work_fn_t fn, void * args)

Dispatch work to all threads and wait for completion.

ck_threadpool_global Forward
ck_threadpool_t * ck_threadpool_global(void)

Get or create the global thread pool. Thread-safe (uses pthread_once internally). Uses ck_get_num_threads() for auto-detection.

ck_threadpool_global_destroy Forward
void ck_threadpool_global_destroy(void)

Destroy the global thread pool. Called during engine shutdown.

ck_threadpool_init Forward
void ck_threadpool_init(void)

Initialize the global thread pool. Called once during engine startup (e.g., from ck_model_init). Uses ck_get_num_threads() for thread count (respects CK_NUM_THREADS env).

ck_threadpool_n_threads Forward
int ck_threadpool_n_threads(const ck_threadpool_t * pool)

Get total thread count (including main thread)

ck_threadpool_pause Forward
void ck_threadpool_pause(ck_threadpool_t * pool)

Pause workers — they sleep on a condition variable (0% CPU). Call between batches or while waiting for interactive input. Workers wake on the next dispatch or resume.

ck_threadpool_resume Forward
void ck_threadpool_resume(ck_threadpool_t * pool)

Resume workers — transition from sleep to spin-wait. Call before starting a new batch of work.

ck_threadpool_shutdown Forward
void ck_threadpool_shutdown(void)

Shut down the global thread pool. Called during engine teardown. Workers are joined and freed.

ck_threadpool_thread_id Forward
int ck_threadpool_thread_id(const ck_threadpool_t * pool)

Get thread index for current thread (0 = main, -1 if not in pool)

ck_tokenizer_add_merge Forward
int ck_tokenizer_add_merge(CKTokenizer * tok, int32_t left, int32_t right, int32_t merged)

ck_tokenizer_add_special_token Forward
int ck_tokenizer_add_special_token(CKTokenizer * tok, const char * name, int32_t id)

Add special token (UNK, BOS, EOS, PAD, MASK).

ck_tokenizer_add_token Forward
int32_t ck_tokenizer_add_token(CKTokenizer * tok, const char * token, int len)

ck_tokenizer_create Forward
CKTokenizer * ck_tokenizer_create(CKTokenizerType type)

ck_tokenizer_create_bpe Forward
CKTokenizer * ck_tokenizer_create_bpe(void)

Create tokenizer with default BPE config.

ck_tokenizer_create_spm Forward
CKTokenizer * ck_tokenizer_create_spm(void)

Create tokenizer with default SPM config.

ck_tokenizer_create_wordpiece Forward
CKTokenizer * ck_tokenizer_create_wordpiece(void)

Create tokenizer with default WordPiece config.

ck_tokenizer_decode Forward
int ck_tokenizer_decode(const CKTokenizer * tok, const int32_t * ids, int num_ids, char * text, int max_len)

ck_tokenizer_detect_space_prefix_style Forward
CKSpacePrefixStyle ck_tokenizer_detect_space_prefix_style(CKTokenizer * tok)

ck_tokenizer_encode Forward
int ck_tokenizer_encode(const CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids)

ck_tokenizer_encode_spm_impl Forward
int ck_tokenizer_encode_spm_impl(const CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids)

ck_tokenizer_encode_spm_llama_impl Forward
int ck_tokenizer_encode_spm_llama_impl(const CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids)

ck_tokenizer_encode_tokens Forward
int ck_tokenizer_encode_tokens(const CKTokenizer * tok, const char * text, int text_len, const char ** out_tokens, int max_tokens)

Encode text and return the tokens as an array of strings.

ck_tokenizer_encode_with_special Forward
int ck_tokenizer_encode_with_special(CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids, bool add_special)

Encode with special token handling.

ck_tokenizer_free Forward
void ck_tokenizer_free(CKTokenizer * tok)

ck_tokenizer_hash Forward
uint32_t ck_tokenizer_hash(const char * key, size_t len)

ck_tokenizer_hash_str Forward
uint32_t ck_tokenizer_hash_str(const char * key)

ck_tokenizer_hash_table_clear Forward
void ck_tokenizer_hash_table_clear(CKTokenizerHashTable * table, bool free_values)

Clear all entries (but keep bucket array).

ck_tokenizer_hash_table_contains Forward
bool ck_tokenizer_hash_table_contains(CKTokenizerHashTable * table, const char * key)

Check if key exists.

ck_tokenizer_hash_table_count Forward
size_t ck_tokenizer_hash_table_count(CKTokenizerHashTable * table)

Get the number of entries.

ck_tokenizer_hash_table_create Forward
CKTokenizerHashTable * ck_tokenizer_hash_table_create(size_t bucket_count)

Create a hash table.

ck_tokenizer_hash_table_delete Forward
int ck_tokenizer_hash_table_delete(CKTokenizerHashTable * table, const char * key, bool free_value)

Delete a key.

ck_tokenizer_hash_table_free Forward
void ck_tokenizer_hash_table_free(CKTokenizerHashTable * table, bool free_values)

Free a hash table.

ck_tokenizer_hash_table_insert Forward
int ck_tokenizer_hash_table_insert(CKTokenizerHashTable * table, const char * key, void * value)

Insert a key-value pair.

ck_tokenizer_hash_table_iterate Forward
int ck_tokenizer_hash_table_iterate(CKTokenizerHashTable * table, CKTokenizerHashCallback callback, void * user_data)

ck_tokenizer_hash_table_keys Forward
size_t ck_tokenizer_hash_table_keys(CKTokenizerHashTable * table, const char ** out_keys, size_t max_keys)

Get all keys as an array.

ck_tokenizer_hash_table_lookup Forward
void * ck_tokenizer_hash_table_lookup(CKTokenizerHashTable * table, const char * key)

Look up a key.

ck_tokenizer_hash_table_lookup_avx Forward
void * ck_tokenizer_hash_table_lookup_avx(CKTokenizerHashTable * table, const char * key)

ck_tokenizer_id_to_token Forward
const char * ck_tokenizer_id_to_token(const CKTokenizer * tok, int32_t id)

ck_tokenizer_init Forward
int ck_tokenizer_init(CKTokenizer * tok)

ck_tokenizer_load Forward
int ck_tokenizer_load(CKTokenizer * tok, const char * path)

ck_tokenizer_load_binary Forward
int ck_tokenizer_load_binary(CKTokenizer * tok, int vocab_size, const int32_t * offsets, const char * strings, int num_merges, const int32_t * merges)

Load vocabulary from memory-mapped binary data.

ck_tokenizer_load_binary_with_scores Forward
int ck_tokenizer_load_binary_with_scores(CKTokenizer * tok, int vocab_size, const int32_t * offsets, const char * strings, const float * scores, const uint8_t * types, int num_merges, const int32_t * merges)

Load vocabulary from memory-mapped binary data with scores and types.

ck_tokenizer_load_gguf Forward
int ck_tokenizer_load_gguf(CKTokenizer * tok, const char * path)

Load vocabulary from GGUF file.

ck_tokenizer_load_json Forward
int ck_tokenizer_load_json(CKTokenizer * tok, const char * path)

Load vocabulary from JSON file (HuggingFace format).

ck_tokenizer_load_merges Forward
int ck_tokenizer_load_merges(CKTokenizer * tok, const char * path)

Load BPE merges from text file.

ck_tokenizer_load_text Forward
int ck_tokenizer_load_text(CKTokenizer * tok, const char * path)

Load vocabulary from text file (one token per line).

ck_tokenizer_lookup Forward
int32_t ck_tokenizer_lookup(const CKTokenizer * tok, const char * token, int len)

ck_tokenizer_lookup_exact Forward
int32_t ck_tokenizer_lookup_exact(const CKTokenizer * tok, const char * token)

ck_tokenizer_lookup_exact_n Forward
int32_t ck_tokenizer_lookup_exact_n(const CKTokenizer * tok, const char * text, int text_len)

ck_tokenizer_lookup_merge Forward
int ck_tokenizer_lookup_merge(const CKTokenizer * tok, int32_t left, int32_t right)

ck_tokenizer_mempool_alloc Forward
void * ck_tokenizer_mempool_alloc(CKTokenizerMemPool * pool, size_t size)

Allocate from pool.

ck_tokenizer_mempool_alloc_aligned Forward
void * ck_tokenizer_mempool_alloc_aligned(CKTokenizerMemPool * pool, size_t size, size_t align)

Allocate aligned memory from pool.

ck_tokenizer_mempool_alloc_count Forward
size_t ck_tokenizer_mempool_alloc_count(CKTokenizerMemPool * pool)

Get allocation count.

ck_tokenizer_mempool_available Forward
size_t ck_tokenizer_mempool_available(CKTokenizerMemPool * pool)

Get available bytes in pool.

ck_tokenizer_mempool_free Forward
void ck_tokenizer_mempool_free(CKTokenizerMemPool * pool)

Free a memory pool.

ck_tokenizer_mempool_init Forward
int ck_tokenizer_mempool_init(CKTokenizerMemPool * pool, size_t size)

Initialize a memory pool.

ck_tokenizer_mempool_reset Forward
void ck_tokenizer_mempool_reset(CKTokenizerMemPool * pool)

Reset pool (mark all memory as free).

ck_tokenizer_mempool_strdup Forward
char * ck_tokenizer_mempool_strdup(CKTokenizerMemPool * pool, const char * str)

Allocate and copy string (strdup equivalent).

ck_tokenizer_mempool_strndup Forward
char * ck_tokenizer_mempool_strndup(CKTokenizerMemPool * pool, const char * str, int len)

Allocate and copy string with length.

ck_tokenizer_mempool_used Forward
size_t ck_tokenizer_mempool_used(CKTokenizerMemPool * pool)

Get used bytes in pool.

ck_tokenizer_reset Forward
void ck_tokenizer_reset(CKTokenizer * tok)

ck_tokenizer_set_add_bos_eos Forward
void ck_tokenizer_set_add_bos_eos(CKTokenizer * tok, bool add_bos, bool add_eos)

ck_tokenizer_set_add_space_prefix Forward
void ck_tokenizer_set_add_space_prefix(CKTokenizer * tok, bool add_space_prefix)

ck_tokenizer_set_space_prefix_style Forward
void ck_tokenizer_set_space_prefix_style(CKTokenizer * tok, CKSpacePrefixStyle style)

ck_tokenizer_set_special_ids Forward
void ck_tokenizer_set_special_ids(CKTokenizer * tok, int32_t unk, int32_t bos, int32_t eos, int32_t pad, int32_t mask)

ck_tokenizer_set_spm_mode Forward
void ck_tokenizer_set_spm_mode(CKTokenizer * tok, CKSpmMode spm_mode)

ck_tokenizer_set_use_trie Forward
void ck_tokenizer_set_use_trie(CKTokenizer * tok, bool use_trie)

ck_tokenizer_utf8_normalize_nfc Forward
size_t ck_tokenizer_utf8_normalize_nfc(const char * src, size_t src_len, char * dst, size_t dst_size)

ck_tokenizer_vocab_size Forward
int ck_tokenizer_vocab_size(const CKTokenizer * tok)

ck_trie_clear Forward
void ck_trie_clear(CKTrie * trie)

ck_trie_create Forward
CKTrie * ck_trie_create(size_t max_nodes)

ck_trie_find_longest Forward
int32_t ck_trie_find_longest(const CKTrie * trie, const char * text, size_t text_len, size_t start_pos, size_t * match_len)

ck_trie_free Forward
void ck_trie_free(CKTrie * trie)

ck_trie_has_prefix Forward
bool ck_trie_has_prefix(const CKTrie * trie, const char * text, size_t text_len, size_t pos)

ck_trie_insert Forward
int ck_trie_insert(CKTrie * trie, const char * token, int32_t token_id, bool is_special, int32_t priority)

ck_trie_node_count Forward
size_t ck_trie_node_count(const CKTrie * trie)

ck_true_bpe_add_merge Forward
int ck_true_bpe_add_merge(CKTrueBPE * bpe, int32_t left_id, int32_t right_id, int32_t merged_id, int32_t priority)

ck_true_bpe_add_merge_by_tokens Forward
int ck_true_bpe_add_merge_by_tokens(CKTrueBPE * bpe, const char * left, const char * right, int32_t priority)

ck_true_bpe_add_special_token Forward
int ck_true_bpe_add_special_token(CKTrueBPE * bpe, const char * token, int32_t id)

ck_true_bpe_add_token Forward
int ck_true_bpe_add_token(CKTrueBPE * bpe, const char * token, int32_t id, float score)

ck_true_bpe_create Forward
CKTrueBPE * ck_true_bpe_create(void)

ck_true_bpe_decode Forward
int ck_true_bpe_decode(const CKTrueBPE * bpe, const int32_t * ids, int num_ids, char * text, int max_len)

ck_true_bpe_detect_space_style Forward
CKSpacePrefixStyle ck_true_bpe_detect_space_style(CKTrueBPE * bpe)

ck_true_bpe_encode Forward
int ck_true_bpe_encode(CKTrueBPE * bpe, const char * text, int text_len, int32_t * ids, int max_ids)

ck_true_bpe_free Forward
void ck_true_bpe_free(CKTrueBPE * bpe)

ck_true_bpe_id_to_token Forward
const char * ck_true_bpe_id_to_token(const CKTrueBPE * bpe, int32_t id)

ck_true_bpe_load_binary Forward
int ck_true_bpe_load_binary(CKTrueBPE * bpe, int vocab_size, const int32_t * offsets, const char * strings, int num_merges, const int32_t * merges)

ck_true_bpe_lookup Forward
int32_t ck_true_bpe_lookup(const CKTrueBPE * bpe, const char * token)

ck_true_bpe_num_merges Forward
int32_t ck_true_bpe_num_merges(const CKTrueBPE * bpe)

ck_true_bpe_set_config Forward
void ck_true_bpe_set_config(CKTrueBPE * bpe, const CKBPEConfig * config)

ck_true_bpe_set_special_ids Forward
void ck_true_bpe_set_special_ids(CKTrueBPE * bpe, int32_t unk, int32_t bos, int32_t eos, int32_t pad)

ck_true_bpe_vocab_size Forward
size_t ck_true_bpe_vocab_size(const CKTrueBPE * bpe)

ck_utf8_byte_to_offset Forward
size_t ck_utf8_byte_to_offset(const char * str, size_t len, size_t byte_offset)

Get the character index corresponding to a byte offset.

ck_utf8_char_length Forward
int ck_utf8_char_length(unsigned char c)

Get the length of a UTF-8 character from its first byte.

ck_utf8_count_chars Forward
size_t ck_utf8_count_chars(const char * str, size_t len)

Count UTF-8 characters in a string.

ck_utf8_decode_2 Forward
uint32_t ck_utf8_decode_2(const char * s)

Get 2-byte UTF-8 sequence value.

ck_utf8_decode_3 Forward
uint32_t ck_utf8_decode_3(const char * s)

Get 3-byte UTF-8 sequence value.

ck_utf8_decode_4 Forward
uint32_t ck_utf8_decode_4(const char * s)

Get 4-byte UTF-8 sequence value.

ck_utf8_first_byte Forward
unsigned char ck_utf8_first_byte(const char * s)

Get the first byte of a UTF-8 character.

ck_utf8_from_cp Forward
int ck_utf8_from_cp(uint32_t cp, char * out)

Write a Unicode code point as UTF-8.

ck_utf8_is_continuation Forward
int ck_utf8_is_continuation(unsigned char c)

Check whether a byte is a UTF-8 continuation byte (bit pattern 10xxxxxx).

ck_utf8_is_valid Forward
bool ck_utf8_is_valid(const char * str, size_t len)

Check if a byte sequence is valid UTF-8.

ck_utf8_is_whitespace Forward
bool ck_utf8_is_whitespace(uint32_t cp)

Check if character is whitespace (Unicode White_Space property).

ck_utf8_next_char Forward
int32_t ck_utf8_next_char(const char ** str, int * out_len)

Get next UTF-8 character, return its code point.

ck_utf8_normalize_nfc Forward
size_t ck_utf8_normalize_nfc(const char * src, size_t src_len, char * dst, size_t dst_size)

Normalize UTF-8 string (Unicode normalization form NFC).

ck_utf8_offset_to_byte Forward
size_t ck_utf8_offset_to_byte(const char * str, size_t len, size_t n)

Get the byte offset of the N-th character.

ck_utf8_validate Forward
size_t ck_utf8_validate(const char * str, size_t len)

Validate a UTF-8 string.

ck_weight_dtype_expr Forward
const char * ck_weight_dtype_expr(const CKBufferSpec * spec)

ckernel_backend_native
CKMathBackend ckernel_backend_native(void)

Obtain the built-in native backend (single-node CPU, C + intrinsics).

clamp_int8 Forward
int8_t clamp_int8(float value)

compute_align Forward
CKV2AlignInfo compute_align(const CKModelConfig * cfg)

compute_rms_scale Forward
float compute_rms_scale(const float * x, int n, float eps)

compute_rms_scale_internal Forward
float compute_rms_scale_internal(const float * x, int n, float eps)

convert_bf16_tensor_to_buf Forward
void convert_bf16_tensor_to_buf(const uint16_t * src, float * dst, size_t count)

convert_f16_to_f32 Forward
void convert_f16_to_f32(float * dst, const uint16_t * src, size_t count)

Convert FP16 tensor to FP32.

convert_f32_to_f16 Forward
void convert_f32_to_f16(uint16_t * dst, const float * src, size_t count)

Convert FP32 tensor to FP16.

convert_float_to_int4 Forward
void convert_float_to_int4(const float * src, uint8_t * dst, size_t count)

convert_float_to_int8 Forward
void convert_float_to_int8(const float * src, int8_t * dst, size_t count)

convert_int4_to_float Forward
void convert_int4_to_float(const uint8_t * src, float * dst, size_t count)

convert_int8_to_float Forward
void convert_int8_to_float(const int8_t * src, float * dst, size_t count)

count_set_bits Forward
int count_set_bits(const char * hex_mask)

cpu_features_init Forward
void cpu_features_init(void)

create_entry Forward
CKTokenizerHashEntry * create_entry(const char * key, const void * value, size_t value_size)

create_node Forward
CKTrieNode * create_node(void)

decode_bpe_token Forward
int decode_bpe_token(const char * token, char * out, int max)

Decode GPT-2 byte-level BPE representation back to actual bytes.

decode_int4 Forward
int8_t decode_int4(uint8_t packed, int index)

decode_layer_parallel Forward
void decode_layer_parallel(float * hidden, const void * ln1_weight, const void * ln2_weight, const void * WQ, const void * WK, const void * WV, const void * WO, const void * W_gate, const void * W_up, const void * W_down, float * k_cache, float * v_cache, int token_index, float * scratch, int embed_dim, int intermediate, int H, int H_kv, int head_dim, int max_seq, float eps, int num_threads)

Process one transformer layer in parallel.

dequant_q4_0_block Forward
void dequant_q4_0_block(const block_q4_0 * block, float * output)

Dequantize a single Q4_0 block to FP32.

dequant_q4_0_row Forward
void dequant_q4_0_row(const void * src, float * dst, size_t n_elements)

Dequantize Q4_0 row (multiple blocks)

dequant_q4_1_block Forward
void dequant_q4_1_block(const block_q4_1 * block, float * output)

Dequantize a single Q4_1 block to FP32.

dequant_q4_1_row Forward
void dequant_q4_1_row(const void * src, float * dst, size_t n_elements)

Dequantize Q4_1 row (multiple blocks)

dequant_q4_k_block Forward
void dequant_q4_k_block(const block_q4_K * block, float * output)

Dequantize a single Q4_K block to FP32.

dequant_q4_k_row Forward
void dequant_q4_k_row(const void * src, float * dst, size_t n_elements)

Dequantize Q4_K row (multiple blocks)

dequant_q5_0_block Forward
void dequant_q5_0_block(const block_q5_0 * block, float * output)

Dequantize a single Q5_0 block to FP32.

dequant_q5_0_row Forward
void dequant_q5_0_row(const void * src, float * dst, size_t n_elements)

Dequantize Q5_0 row (multiple blocks)

dequant_q5_1_block Forward
void dequant_q5_1_block(const block_q5_1 * block, float * output)

Dequantize a single Q5_1 block to FP32.

dequant_q5_1_row Forward
void dequant_q5_1_row(const void * src, float * dst, size_t n_elements)

Dequantize Q5_1 row (multiple blocks)

dequant_q6_k_block Forward
void dequant_q6_k_block(const block_q6_K * block, float * output)

Dequantize a single Q6_K block to FP32.

dequant_q6_k_row Forward
void dequant_q6_k_row(const void * src, float * dst, size_t n_elements)

Dequantize Q6_K row (multiple blocks)

dequant_q8_0_block Forward
void dequant_q8_0_block(const block_q8_0 * block, float * output)

Dequantize a single Q8_0 block to FP32.

dequant_q8_0_row Forward
void dequant_q8_0_row(const void * src, float * dst, size_t n_elements)

Dequantize Q8_0 row (multiple blocks)

dequant_row Forward
void dequant_row(CKDataType dtype, const void * src, float * dst, size_t n_elements)

Dequantize a row of quantized data to FP32.

detach_allocation Forward
ck_huge_alloc_entry_t * detach_allocation(void * ptr)

detect_chat_template Forward
ChatTemplateType detect_chat_template(const char * model_name)

detect_physical_cores Forward
int detect_physical_cores(void)

dot_f16 Forward
float dot_f16(const uint16_t * w_f16, const float * x, int K)

dot_fp32_q5_0_block Forward
float dot_fp32_q5_0_block(const float * x, const block_q5_0 * block)

Compute dot product of FP32 input with Q5_0 weight block, with online Q8 quantization.

dot_fp32_q8_0_block Forward
float dot_fp32_q8_0_block(const float * x, const block_q8_0 * block)

Compute dot product of FP32 input with Q8_0 weight block, with online Q8 quantization.

dot_q4_0 Forward
float dot_q4_0(const void * w_q4_0, const float * x, int K)

dot_q4_1 Forward
float dot_q4_1(const void * w_q4_1, const float * x, int K)

dot_q4_k Forward
float dot_q4_k(const void * w_q4k, const float * x, int K)

Compute dot product of Q4_K row with FP32 vector.

dot_q4_k_q8_k_ref Forward
float dot_q4_k_q8_k_ref(const block_q4_K * w, const block_q8_K * x, int k)

dot_q5_0 Forward
float dot_q5_0(const void * w_q5_0, const float * x, int K)

dot_q5_0_q8_k_32_sse Forward
float dot_q5_0_q8_k_32_sse(const block_q5_0 * bw, const block_q8_K * ba, int q8_offset)

dot_q5_1 Forward
float dot_q5_1(const void * w_q5_1, const float * x, int K)

dot_q6_k_q8_k_256_sse Forward
float dot_q6_k_q8_k_256_sse(const block_q6_K * bw, const block_q8_K * ba)

SSE-optimized dot product for Q6_K x Q8_K. Q6_K block layout: ql: 128 bytes (low 4 bits); qh: 64 bytes (high 2 bits); scales: 16 bytes (int8 scales); d: FP16 super-scale.

dot_q6_k_q8_k_ref Forward
float dot_q6_k_q8_k_ref(const block_q6_K * w, const block_q8_K * x, int K)

Scalar dot product for Q6_K x Q8_K.

dot_q6_k_ref Forward
float dot_q6_k_ref(const block_q6_K * w, const float * x, int K)

dot_q8_0 Forward
float dot_q8_0(const void * w_q8_0, const float * x, int K)

embedding_backward Backward
void embedding_backward(const int32_t * token_ids, int token_count, const float * d_output, float * d_token_embeddings, float * d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)

Backward pass / gradient computation

embedding_backward_bf16 Backward
void embedding_backward_bf16(const int32_t * token_ids, int token_count, const uint16_t * d_output, uint16_t * d_token_embeddings, uint16_t * d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)

Backward pass / gradient computation

embedding_forward Forward
void embedding_forward(const int32_t * token_ids, int token_count, int vocab_size, const float * token_embeddings, const float * pos_embeddings, float * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)

Forward pass computation

embedding_forward_bf16 Forward
void embedding_forward_bf16(const int32_t * token_ids, int token_count, int vocab_size, const uint16_t * token_embeddings, const uint16_t * pos_embeddings, uint16_t * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)

Forward pass computation

embedding_forward_q4_k Forward
void embedding_forward_q4_k(const int32_t * token_ids, int token_count, int vocab_size, const void * token_embeddings, const float * pos_embeddings, float * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)

Forward pass computation

embedding_forward_q6_k Forward
void embedding_forward_q6_k(const int32_t * token_ids, int token_count, int vocab_size, const void * token_embeddings, const float * pos_embeddings, float * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)

Forward pass computation

embedding_forward_q8_0 Forward
void embedding_forward_q8_0(const int32_t * token_ids, int token_count, int vocab_size, const void * token_embeddings, const float * pos_embeddings, float * output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)

Forward pass computation

emit_body_fields Forward
void emit_body_fields(FILE * out, const CKIRV2Graph * graph, CKBufferRole role_filter, int activation_group)

emit_body_values Forward
size_t emit_body_values(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan, CKBufferRole role_filter, int activation_group)

emit_bump_bytes_assignment Forward
void emit_bump_bytes_assignment(FILE * out, const char * indent, const char * struct_prefix, const char * name, const CKDimToken * shape)

emit_bump_bytes_assignment_weight_dtype Forward
void emit_bump_bytes_assignment_weight_dtype(FILE * out, const char * indent, const char * struct_prefix, const char * name, const CKDimToken * shape, const char * dtype_expr)

emit_dim_expr Forward
void emit_dim_expr(FILE * out, CKDimKind dim)

emit_footer_fields Forward
void emit_footer_fields(FILE * out, const CKIRV2Graph * graph, CKBufferRole role_filter, int activation_group)

emit_footer_values Forward
void emit_footer_values(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan, CKBufferRole role_filter, int activation_group, size_t * offset)

emit_global_aliases_to_layer Forward
void emit_global_aliases_to_layer(FILE * out)

emit_global_allocations Forward
void emit_global_allocations(FILE * out)

emit_global_offset_fields Forward
void emit_global_offset_fields(FILE * out)

emit_header_fields Forward
void emit_header_fields(FILE * out, const CKIRV2Graph * graph, CKBufferRole role_filter, int activation_group)

emit_header_values Forward
void emit_header_values(FILE * out, const CKIRV2Graph * graph, const CKMemPlan * plan, CKBufferRole role_filter, int activation_group, size_t * offset)

emit_kernel_manifest Forward
int emit_kernel_manifest(const CKIRGraph * forward, const char * runtime_path)

emit_layer_allocations Forward
void emit_layer_allocations(FILE * out)

emit_layer_offsets_struct Forward
void emit_layer_offsets_struct(FILE * out)

emit_library_api Forward
void emit_library_api(FILE * out, const CKIRGraph * forward)

emit_model_struct Forward
void emit_model_struct(FILE * out)

emit_offset_field Forward
void emit_offset_field(FILE * out, const char * name)

emit_plan_sources Forward
int emit_plan_sources(FILE * f, const CKPlanStep * plan, size_t plan_count, const CKIRGraph * cfg, const char ** seen, size_t * seen_count, size_t seen_cap)

emit_runtime_preamble Forward
int emit_runtime_preamble(FILE * out)

emit_schedule_block Forward
void emit_schedule_block(FILE * out, const CKIRV2Graph * graph, const char * func_name, const char * label, const char * runtime_sym)

emit_sgd_update Forward
void emit_sgd_update(FILE * out)

emit_shape_expr Forward
void emit_shape_expr(FILE * out, const CKDimToken * shape)

emit_span_field Forward
void emit_span_field(FILE * out, const char * label)

emit_span_value Forward
void emit_span_value(FILE * out, const char * label, size_t offset, size_t size, int comma)

emit_training_conditional_assignment Forward
void emit_training_conditional_assignment(FILE * out, const char * indent, const char * struct_prefix, const char * name, const CKDimToken * shape)

emit_unique_source Forward
int emit_unique_source(FILE * f, const char * path, const char ** seen, size_t * seen_count, size_t seen_cap)

emit_zero_grad Forward
void emit_zero_grad(FILE * out)

encode_chunk Forward
int encode_chunk(CKTrueBPE * bpe, const char * chunk, int chunk_len, int32_t * ids, int max_ids, CKBPETokenList * list)

encode_int4_nibble Forward
uint8_t encode_int4_nibble(int8_t value)

encode_text_segment Forward
int encode_text_segment(CKTrueBPE * bpe, const char * text, int text_len, int32_t * ids, int max_ids)

engine_thread_func Forward
void * engine_thread_func(void * arg)

eos_is_potential_prefix Forward
bool eos_is_potential_prefix(const char * token)

Check whether the token might be the start of an EOS pattern.

eos_pattern_init Forward
void eos_pattern_init(ChatTemplateType tmpl)

eos_pattern_process Forward
bool eos_pattern_process(const char * token_text, char * out_buf, size_t * out_len, void (*output_fn)(char *, size_t *, const char *), ChatTemplateType tmpl)

Process a token for EOS pattern detection.

eos_pattern_reset Forward
void eos_pattern_reset(void)

find_best_merge Forward
int find_best_merge(const CKTrueBPE * bpe, const CKBPETokenList * list, size_t * best_pos, const CKBPEMerge ** best_merge)

find_buffer_by_name Forward
int find_buffer_by_name(const CKIRV2Graph * graph, const char * name)

find_longest_match Forward
int32_t find_longest_match(const CKTokenizer * tok, const char * text, size_t text_len, size_t pos, size_t * match_len)

find_longest_match_hash Forward
int32_t find_longest_match_hash(const CKTokenizer * tok, const char * text, size_t text_len, size_t pos, size_t * match_len)

find_longest_match_trie Forward
int32_t find_longest_match_trie(const CKTokenizer * tok, const char * text, size_t text_len, size_t pos, size_t * match_len)

find_model_in_cache
bool find_model_in_cache(const char * model_name, char * lib_out, char * weights_out, size_t out_size)

find_object_range Forward
int find_object_range(const char * json, const char * key, const char ** out_start, size_t * out_len)

flatten_head_major Forward
void flatten_head_major(const float * attn_out, float * dst, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

float_tensor_to_bf16 Forward
void float_tensor_to_bf16(const float * src, uint16_t * dst, size_t count)

float_to_bf16 Forward
uint16_t float_to_bf16(float f)

format_bandwidth Forward
const char * format_bandwidth(float bw_gbs, char * buf, size_t buf_size)

format_size Forward
const char * format_size(uint64_t size_mb, char * buf, size_t buf_size)

free_entry Forward
void free_entry(CKTokenizerHashEntry * entry, bool free_value)

fused_kernels_compute_kv_tile Forward
int fused_kernels_compute_kv_tile(int l1_size, int head_dim, int bytes_per_elem)

Compute optimal KV tile size for flash attention.

fused_kernels_report_stats Forward
void fused_kernels_report_stats(int hidden, int num_layers, int seq_len)

Report memory savings from mega-fusion.

fused_kernels_validate_constraints Forward
int fused_kernels_validate_constraints(int l1_size, int head_dim, int kv_tile_size, int bytes_per_elem)

Validate cache constraints for fusion.

fused_output_projection_residual Forward
void fused_output_projection_residual(float * output, const float * o_all, const float * W_o, const float * b_o, const float * residual, int hidden, int num_heads, int head_dim)

Fused output projection with residual add.

geglu_backward_fp32 Backward
void geglu_backward_fp32(const float * x, const float * d_out, float * d_x, int n_tokens, int dim)

GeGLU backward pass (fp32). Test: test_geglu.py::TestGeGLU::test_geglu_backward_fp32

geglu_forward_bf16 Forward
void geglu_forward_bf16(const uint16_t * x, uint16_t * out, int tokens, int dim, float * scratch)

GeGLU forward pass (bf16). Test: test_geglu.py::TestGeGLU::test_geglu_forward_bf16

geglu_forward_fp32 Forward
void geglu_forward_fp32(const float * x, float * out, int tokens, int dim)

GeGLU forward pass (fp32). Test: test_geglu.py::TestGeGLU::test_geglu_forward_fp32

gemv_f16 Forward
void gemv_f16(float * y, const uint16_t * W, const float * x, int M, int K)

Auto-dispatch GEMV based on available SIMD.

gemv_f16_backward Backward
void gemv_f16_backward(float * dX, const uint16_t * W, const float * dY, int M, int K)

Auto-dispatch backward.

gemv_f16_backward_ref Backward
void gemv_f16_backward_ref(float * dX, const uint16_t * W, const float * dY, int M, int K)

Backward pass: compute input gradient (scalar reference)

gemv_f16_ref Forward
void gemv_f16_ref(float * y, const uint16_t * W, const float * x, int M, int K)

Matrix-vector multiply with FP16 weights (scalar reference)

gemv_fused_q5_0_bias Forward
void gemv_fused_q5_0_bias(float * y, const void * W, const float * x, const float * bias, int M, int K)

gemv_fused_q5_0_bias_dispatch Forward
void gemv_fused_q5_0_bias_dispatch(float * y, const void * W, const float * x, const float * bias, int M, int K)

gemv_fused_q5_0_bias_parallel_omp Forward
void gemv_fused_q5_0_bias_parallel_omp(float * y, const void * W, const float * x, const float * bias, int M, int K)

gemv_fused_q8_0_bias Forward
void gemv_fused_q8_0_bias(float * y, const void * W, const float * x, const float * bias, int M, int K)

gemv_fused_q8_0_bias_dispatch Forward
void gemv_fused_q8_0_bias_dispatch(float * y, const void * W, const float * x, const float * bias, int M, int K)

gemv_nt_q5_0_head_major_output Forward
void gemv_nt_q5_0_head_major_output(float * output, const float * attn_out, const void * wo, const float * bias, int tokens, int embed_dim, int num_heads, int head_dim)

Output projection reading head-major attention output (Q5_0 weights)

gemv_q4_0 Forward
void gemv_q4_0(float * y, const void * W, const float * x, int M, int K)

Auto-dispatch GEMV.

gemv_q4_0_backward Backward
void gemv_q4_0_backward(float * dX, const void * W, const float * dY, int M, int K)

Auto-dispatch backward.

gemv_q4_0_backward_ref Backward
void gemv_q4_0_backward_ref(float * dX, const void * W, const float * dY, int M, int K)

Backward pass: compute input gradient.

gemv_q4_0_ref Forward
void gemv_q4_0_ref(float * y, const void * W, const float * x, int M, int K)

Matrix-vector multiply with Q4_0 weights (scalar reference)

gemv_q4_1 Forward
void gemv_q4_1(float * y, const void * W, const float * x, int M, int K)

Auto-dispatch GEMV.

gemv_q4_1_backward Backward
void gemv_q4_1_backward(float * dX, const void * W, const float * dY, int M, int K)

Auto-dispatch backward.

gemv_q4_1_backward_ref Backward
void gemv_q4_1_backward_ref(float * dX, const void * W, const float * dY, int M, int K)

Backward pass: compute input gradient.

gemv_q4_1_ref Forward
void gemv_q4_1_ref(float * y, const void * W, const float * x, int M, int K)

Matrix-vector multiply with Q4_1 weights (scalar reference)

gemv_q4_k Forward
void gemv_q4_k(float * y, const void * W, const float * x, int M, int K)

Auto-dispatch GEMV based on available SIMD.

gemv_q4_k_backward Backward
void gemv_q4_k_backward(float * dX, const void * W, const float * dY, int M, int K)

Auto-dispatch backward.

gemv_q4_k_backward_ref Backward
void gemv_q4_k_backward_ref(float * dX, const void * W, const float * dY, int M, int K)

Backward pass: compute input gradient (scalar reference)

gemv_q4_k_q8_k Forward
void gemv_q4_k_q8_k(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q4_k_q8_k_amx Forward
void gemv_q4_k_q8_k_amx(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q4_k_q8_k_avx Forward
void gemv_q4_k_q8_k_avx(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q4_k_q8_k_avx2 Forward
void gemv_q4_k_q8_k_avx2(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q4_k_q8_k_parallel Forward
void gemv_q4_k_q8_k_parallel(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)

gemv_q4_k_q8_k_parallel_simd Forward
void gemv_q4_k_q8_k_parallel_simd(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)

gemv_q4_k_q8_k_ref Forward
void gemv_q4_k_q8_k_ref(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q4_k_q8_k_sse Forward
void gemv_q4_k_q8_k_sse(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q4_k_q8_k_vnni Forward
void gemv_q4_k_q8_k_vnni(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q4_k_ref Forward
void gemv_q4_k_ref(float * y, const void * W, const float * x, int M, int K)

Matrix-vector multiply with Q4_K weights (scalar reference)

gemv_q5_0 Forward
void gemv_q5_0(float * y, const void * W, const float * x, int M, int K)

Auto-dispatch GEMV for Q5_0 weights based on CPU features.

gemv_q5_0_backward Backward
void gemv_q5_0_backward(float * dX, const void * W, const float * dY, int M, int K)

Auto-dispatch backward.

gemv_q5_0_backward_ref Backward
void gemv_q5_0_backward_ref(float * dX, const void * W, const float * dY, int M, int K)

Backward pass: compute input gradient.

gemv_q5_0_from_fp32 Forward
void gemv_q5_0_from_fp32(float * out, const void * W_q5_0, const float * x_fp32, const float * bias, int M, int K, block_q8_0 * x_q8_scratch)

gemv_q5_0_parallel Forward
void gemv_q5_0_parallel(float * y, const void * W, const float * x, int M, int K, int ith, int nth)

Parallel reference GEMV for Q5_0 × FP32.

gemv_q5_0_parallel_simd Forward
void gemv_q5_0_parallel_simd(float * y, const void * W, const float * x, int M, int K, int ith, int nth)

Parallel SIMD GEMV for Q5_0 × FP32 with prefetching.

gemv_q5_0_q8_0 Forward
void gemv_q5_0_q8_0(float * y, const void * W, const void * x_q8, int M, int K)

Matrix-vector multiply with Q5_0 weights and Q8_0 input.

gemv_q5_0_q8_0_parallel_omp Forward
void gemv_q5_0_q8_0_parallel_omp(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q5_0_q8_0_parallel_simd Forward
void gemv_q5_0_q8_0_parallel_simd(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)

Parallel SIMD GEMV for Q5_0 x Q8_0 with prefetching.

gemv_q5_0_ref Forward
void gemv_q5_0_ref(float * y, const void * W, const float * x, int M, int K)

Matrix-vector multiply with Q5_0 weights (scalar reference)

gemv_q5_1 Forward
void gemv_q5_1(float * y, const void * W, const float * x, int M, int K)

Auto-dispatch GEMV.

gemv_q5_1_backward Backward
void gemv_q5_1_backward(float * dX, const void * W, const float * dY, int M, int K)

Auto-dispatch backward.

gemv_q5_1_backward_ref Backward
void gemv_q5_1_backward_ref(float * dX, const void * W, const float * dY, int M, int K)

Backward pass: compute input gradient.

gemv_q5_1_ref Forward
void gemv_q5_1_ref(float * y, const void * W, const float * x, int M, int K)

Matrix-vector multiply with Q5_1 weights (scalar reference)

gemv_q5_k Forward
void gemv_q5_k(float * y, const void * W, const float * x, int M, int K)

gemv_q5_k_ref Forward
void gemv_q5_k_ref(float * y, const void * W, const float * x, int M, int K)

gemv_q6_k Forward
void gemv_q6_k(float * y, const void * W, const float * x, int M, int K)

gemv_q6_k_q8_k Forward
void gemv_q6_k_q8_k(float * y, const void * W, const void * x_q8, int M, int K)

GEMV: y = W @ x where W is Q6_K and x is Q8_K.

gemv_q6_k_q8_k_avx Forward
void gemv_q6_k_q8_k_avx(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q6_k_q8_k_avx2 Forward
void gemv_q6_k_q8_k_avx2(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q6_k_q8_k_avx512 Forward
void gemv_q6_k_q8_k_avx512(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q6_k_q8_k_avx512_vbmi Forward
void gemv_q6_k_q8_k_avx512_vbmi(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q6_k_q8_k_parallel Forward
void gemv_q6_k_q8_k_parallel(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)

Parallel reference GEMV for Q6_K × Q8_K.

gemv_q6_k_q8_k_parallel_simd Forward
void gemv_q6_k_q8_k_parallel_simd(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)

Parallel SIMD GEMV for Q6_K × Q8_K.

gemv_q6_k_q8_k_ref Forward
void gemv_q6_k_q8_k_ref(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q6_k_q8_k_sse Forward
void gemv_q6_k_q8_k_sse(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q8_0 Forward
void gemv_q8_0(float * y, const void * W, const float * x, int M, int K)

Auto-dispatch GEMV for Q8_0 weights based on CPU features.

gemv_q8_0_backward Backward
void gemv_q8_0_backward(float * dX, const void * W, const float * dY, int M, int K)

Auto-dispatch backward.

gemv_q8_0_backward_ref Backward
void gemv_q8_0_backward_ref(float * dX, const void * W, const float * dY, int M, int K)

Backward pass: compute input gradient (scalar reference)

gemv_q8_0_from_fp32 Forward
void gemv_q8_0_from_fp32(float * out, const void * W_q8_0, const float * x_fp32, const float * bias, int M, int K, block_q8_0 * x_q8_scratch)

gemv_q8_0_parallel_simd Forward
void gemv_q8_0_parallel_simd(float * y, const void * W, const float * x, int M, int K, int ith, int nth)

Parallel SIMD GEMV for Q8_0 weights x FP32 input with prefetching.

gemv_q8_0_q8_0 Forward
void gemv_q8_0_q8_0(float * y, const void * W, const void * x_q8, int M, int K)

Matrix-vector multiply with Q8_0 weights and Q8_0 input.

gemv_q8_0_q8_0_parallel Forward
void gemv_q8_0_q8_0_parallel(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)

Parallel reference GEMV for Q8_0 x Q8_0.

gemv_q8_0_q8_0_parallel_omp Forward
void gemv_q8_0_q8_0_parallel_omp(float * y, const void * W, const void * x_q8, int M, int K)

gemv_q8_0_q8_0_parallel_simd Forward
void gemv_q8_0_q8_0_parallel_simd(float * y, const void * W, const void * x_q8, int M, int K, int ith, int nth)

Parallel SIMD GEMV for Q8_0 x Q8_0 with prefetching.

gemv_q8_0_ref Forward
void gemv_q8_0_ref(float * y, const void * W, const float * x, int M, int K)

Matrix-vector multiply with Q8_0 weights (scalar reference)

get_cache_dir
const char * get_cache_dir(void)

get_cpu_info Forward
const CPUInfo * get_cpu_info(void)

get_current_cpu Forward
int get_current_cpu(void)

get_numa_node_for_cpu Forward
int get_numa_node_for_cpu(int cpu)

get_optimal_decode_threads Forward
int get_optimal_decode_threads(void)

get_q5_k_scale_min Forward
void get_q5_k_scale_min(int j, const uint8_t * scales, uint8_t * scale, uint8_t * min)

get_time_ms Forward
double get_time_ms(void)

global_pool_init Forward
void global_pool_init(void)

gpt2_decode_byte Forward
int gpt2_decode_byte(const unsigned char * s, int len)

gpt2_pretokenize Forward
int gpt2_pretokenize(const char * text, int text_len, PretokChunk * chunks, int max_chunks)

gradient_accumulate_bf16 Forward
void gradient_accumulate_bf16(uint16_t * dst, const uint16_t * src, size_t numel)

Accumulate gradients: dst += src (bf16)

gradient_accumulate_f32 Forward
void gradient_accumulate_f32(float * dst, const float * src, size_t numel)

Accumulate gradients: dst += src (fp32)

gradient_clip_norm_bf16 Forward
float gradient_clip_norm_bf16(uint16_t * grad, size_t numel, float max_norm)

Clip gradient norm (bf16)

gradient_clip_norm_f32 Forward
float gradient_clip_norm_f32(float * grad, size_t numel, float max_norm)

Clip gradient norm (fp32)

gradient_scale_bf16 Forward
void gradient_scale_bf16(uint16_t * grad, size_t numel, float scale)

Scale gradients: grad *= scale (bf16)

gradient_scale_f32 Forward
void gradient_scale_f32(float * grad, size_t numel, float scale)

Scale gradients by a constant: grad *= scale (fp32)

handle_sigint Forward
void handle_sigint(int sig)

has_cpu_flag Forward
int has_cpu_flag(const char * flags, const char * flag)

hash_pair Forward
uint32_t hash_pair(int32_t left, int32_t right)

hash_string Forward
uint32_t hash_string(const char * s, int len)

hmax256_ps_fused Forward
float hmax256_ps_fused(__m256 v)

hsum256_ps_fused Forward
float hsum256_ps_fused(__m256 v)

hsum_epi32_sse Forward
int32_t hsum_epi32_sse(__m128i v)

im2patch Forward
void im2patch(const float * image, float * patches, int C, int H, int W, int P)

im2patch: Transforms an image into a sequence of flattened patches.

im2patch_bf16 Forward
void im2patch_bf16(const uint16_t * image, uint16_t * patches, int C, int H, int W, int P)

init_tokens_from_text Forward
int init_tokens_from_text(CKTrueBPE * bpe, CKBPETokenList * list, const char * text, int text_len)

is_activation_role Forward
int is_activation_role(CKBufferRole role)

is_bpe_digit Forward
bool is_bpe_digit(const char * s, int len)

is_bpe_letter Forward
bool is_bpe_letter(const char * s, int len)

is_bpe_newline Forward
bool is_bpe_newline(const char * s, int len)

is_bpe_punct Forward
bool is_bpe_punct(const char * s, int len)

is_digit Forward
bool is_digit(unsigned char c)

is_eos_token Forward
bool is_eos_token(const CLIOptions * opt, int token)

is_footer_global Forward
int is_footer_global(const char * name)

is_gpt2_space Forward
bool is_gpt2_space(const char * s, int len)

is_letter Forward
bool is_letter(unsigned char c)

is_whitespace Forward
bool is_whitespace(unsigned char c)

is_word_prefix_char Forward
bool is_word_prefix_char(const char * s, int len)

json_match_char Forward
int json_match_char(JSONParser * p, char c)

json_parse_int Forward
int json_parse_int(JSONParser * p, int * out)

json_parse_string Forward
int json_parse_string(JSONParser * p, char * buf, int max_len)

json_skip_value Forward
void json_skip_value(JSONParser * p)

json_skip_whitespace Forward
void json_skip_whitespace(JSONParser * p)

kv_cache_repack_head_major_inplace
void kv_cache_repack_head_major_inplace(float * buf, int num_heads, int tokens, int cache_capacity, int aligned_head_dim)

kv_cache_store
void kv_cache_store(float *__restrict kv_cache_k, float *__restrict kv_cache_v, const float *__restrict k, const float *__restrict v, int layer, int pos, int num_kv_heads, int head_dim, int max_seq_len)

kv_cache_write_head_major
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)

layout_transformer_from_ir Forward
void layout_transformer_from_ir(TransformerModel * m, const CKIRGraph * ir)

Compute a simple forward-only layout for TransformerModel based on two inputs: the CKModelConfig (dims, heads, vocab, context) and the IR graph structure (number of layers, op types).

list_available_models Forward
void list_available_models(void)

load_eos_from_vocab_json Forward
bool load_eos_from_vocab_json(const char * weights_path, CLIOptions * opt)

load_manifest Forward
int load_manifest(const char * path, ManifestEntry ** entries, int * num_entries)

load_model_api Forward
bool load_model_api(const char * lib_path, ModelAPI * api)

load_weights Forward
int load_weights(QWEN2_DECODEModel * model, const char * bump_path, const char * manifest_path)

load_weights_from_bump Forward
int load_weights_from_bump(void * model, const char * bump_path)

logits_copy_to_position Forward
void logits_copy_to_position(const float *__restrict src, float *__restrict dst, int position, int vocab_size)

Copy logits to position-indexed location in output buffer.

main Forward
int main(int argc, char ** argv)

map_forward_to_backward Backward
CKOpType map_forward_to_backward(CKOpType op)

Map a forward op type to its corresponding backward op type (returns a CKOpType; performs no gradient computation itself).

match_special_token Forward
int match_special_token(const CKTrueBPE * bpe, const char * text, int text_len, int pos)

max_k_for_query Forward
int max_k_for_query(int t_q, int T_q, int T_k)

mega_fuse_get_optimal_tiles Forward
void mega_fuse_get_optimal_tiles(int * q_tile, int * kv_tile, int head_dim)

Get optimal tile sizes for current CPU.

mega_fuse_output_proj_residual Forward
void mega_fuse_output_proj_residual(const float * attn_token, const float * wo, const float * bo, const float * residual, float * output, int embed_dim, int aligned_embed_dim, int num_heads, int head_dim, int aligned_head_dim)

mega_fuse_report_stats Forward
void mega_fuse_report_stats(int hidden, int num_layers, int seq_len)

Report memory savings from mega-fusion.

merge_hash Forward
size_t merge_hash(uint64_t key, size_t num_buckets)

merge_key Forward
uint64_t merge_key(int32_t left_id, int32_t right_id)

merge_table_create Forward
CKMergeTable * merge_table_create(size_t num_buckets)

merge_table_free Forward
void merge_table_free(CKMergeTable * table)

merge_table_insert Forward
int merge_table_insert(CKMergeTable * table, const CKBPEMerge * merge)

merge_table_lookup Forward
const CKBPEMerge * merge_table_lookup(const CKMergeTable * table, int32_t left_id, int32_t right_id)

model_align_elems Forward
int model_align_elems(int elems, int elem_bytes, int align_bytes)

model_decode Forward
void model_decode(MODELModel * model, const int * token, int token_index)

model_decode_token Forward
void model_decode_token(MODELModel * model, const int * token, int token_index)

model_forward Forward
void model_forward(MODELModel * model, const int * tokens, int num_tokens)

Forward pass computation

model_forward_prefill_impl Forward
void model_forward_prefill_impl(MODELModel * model, const int * tokens, int num_tokens)

Forward pass computation

model_layer_0_decode Forward
void model_layer_0_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_0_prefill Forward
void model_layer_0_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_10_decode Forward
void model_layer_10_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_10_prefill Forward
void model_layer_10_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_11_decode Forward
void model_layer_11_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_11_prefill Forward
void model_layer_11_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_12_decode Forward
void model_layer_12_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_12_prefill Forward
void model_layer_12_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_13_decode Forward
void model_layer_13_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_13_prefill Forward
void model_layer_13_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_14_decode Forward
void model_layer_14_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_14_prefill Forward
void model_layer_14_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_15_decode Forward
void model_layer_15_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_15_prefill Forward
void model_layer_15_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_16_decode Forward
void model_layer_16_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_16_prefill Forward
void model_layer_16_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_17_decode Forward
void model_layer_17_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_17_prefill Forward
void model_layer_17_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_18_decode Forward
void model_layer_18_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_18_prefill Forward
void model_layer_18_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_19_decode Forward
void model_layer_19_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_19_prefill Forward
void model_layer_19_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_1_decode Forward
void model_layer_1_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_1_prefill Forward
void model_layer_1_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_20_decode Forward
void model_layer_20_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_20_prefill Forward
void model_layer_20_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_21_decode Forward
void model_layer_21_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_21_prefill Forward
void model_layer_21_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_22_decode Forward
void model_layer_22_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_22_prefill Forward
void model_layer_22_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_23_decode Forward
void model_layer_23_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_23_prefill Forward
void model_layer_23_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_2_decode Forward
void model_layer_2_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_2_prefill Forward
void model_layer_2_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_3_decode Forward
void model_layer_3_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_3_prefill Forward
void model_layer_3_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_4_decode Forward
void model_layer_4_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_4_prefill Forward
void model_layer_4_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_5_decode Forward
void model_layer_5_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_5_prefill Forward
void model_layer_5_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_6_decode Forward
void model_layer_6_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_6_prefill Forward
void model_layer_6_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_7_decode Forward
void model_layer_7_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_7_prefill Forward
void model_layer_7_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_8_decode Forward
void model_layer_8_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_8_prefill Forward
void model_layer_8_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_9_decode Forward
void model_layer_9_decode(MODELModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_layer_9_prefill Forward
void model_layer_9_prefill(MODELModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

model_model_allocate Forward
int model_model_allocate(MODELModel * model)

model_model_free Forward
void model_model_free(MODELModel * model)

model_residual_add_token_major Forward
void model_residual_add_token_major(const float * a, const float * b, float * out, int tokens, int aligned_embed_dim)

model_verify_canaries Forward
int model_verify_canaries(MODELModel * model)

moe_accumulate_expert_f32 Forward
void moe_accumulate_expert_f32(float * output, const float * expert_output, float routing_weight, int hidden_dim)

Accumulate expert output: output += routing_weight * expert_output.

op_name Forward
const char * op_name(CKOpType op)

out_proj_head_major_q5_0_q8_0 Forward
void out_proj_head_major_q5_0_q8_0(const uint8_t * attn_q8, const void * wo, const float * bias, float * output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

out_proj_head_major_q8_0_q8_0 Forward
void out_proj_head_major_q8_0_q8_0(const uint8_t * attn_q8, const void * wo, const float * bias, float * output, int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim)

output_append Forward
void output_append(char * buf, size_t * len, const char * text)

output_flush Forward
void output_flush(char * buf, size_t * len)

output_token Forward
void output_token(char * buf, size_t * len, const char * token)

pack_a_panel Forward
void pack_a_panel(const float * A, int lda, float * Ap, int mc, int kc, int mr)

pack_b_panel Forward
void pack_b_panel(const float * B, int ldb, float * Bp, int kc, int nc, int nr)

parse_args Forward
bool parse_args(int argc, char ** argv, CLIOptions * opt)

parse_eos_ids Forward
bool parse_eos_ids(const char * arg, CLIOptions * opt)

parse_float_field_any Forward
int parse_float_field_any(const char * json, size_t len, const char *const * keys, float * out_value)

parse_float_field_in_range Forward
int parse_float_field_in_range(const char * json, size_t len, const char * key, float * out_value)

parse_int_field Forward
int parse_int_field(const char * json, const char * key, int * out_value)

parse_int_field_any Forward
int parse_int_field_any(const char * json, size_t len, const char *const * keys, int * out_value)

parse_int_field_in_range Forward
int parse_int_field_in_range(const char * json, size_t len, const char * key, int * out_value)

parse_manifest_entry Forward
bool parse_manifest_entry(const char * json, const char * name, size_t * offset, size_t * size)

parse_manifest_int Forward
int parse_manifest_int(const char * json, const char * key)

parse_op Forward
CKOpType parse_op(const char * s)

parse_u64 Forward
unsigned long long parse_u64(const char * s)

patch2im Backward
void patch2im(const float * d_patches, float * d_image, int C, int H, int W, int P)

patch2im: Accumulates gradients from patches back into the image. (Backward pass)

patch2im_bf16 Backward
void patch2im_bf16(const uint16_t * d_patches, uint16_t * d_image, int C, int H, int W, int P)

pcie_bandwidth_gbs Forward
float pcie_bandwidth_gbs(int gen, int width)

plan_size Forward
size_t plan_size(const CKMemPlan * plan, int idx)

pool_new_block Forward
CKPoolBlock * pool_new_block(size_t capacity)

preprocess_bpe_spaces Forward
int preprocess_bpe_spaces(const char * text, int text_len, char * out, int out_max, CKSpacePrefixStyle style)

preprocess_spm_llama_text Forward
int preprocess_spm_llama_text(const char * text, int text_len, char * out, int out_max, bool add_space_prefix)

preprocess_spm_text Forward
int preprocess_spm_text(const char * text, int text_len, char * out, int out_max, bool add_space_prefix)

preprocess_text Forward
int preprocess_text(const CKTrueBPE * bpe, const char * text, int text_len, char * out, int out_max)

print_banner Forward
void print_banner(void)

print_cpu_info Forward
void print_cpu_info(void)

print_header Forward
void print_header(const char * title)

print_help Forward
void print_help(const char * prog)

print_ok Forward
void print_ok(const char * msg)

print_progress Forward
void print_progress(int token_id, float token_per_sec)

print_section Forward
void print_section(const char * title)

print_tree_item Forward
void print_tree_item(int level, int is_last, const char * fmt, ...)

print_usage Forward
void print_usage(const char * prog)

print_version Forward
void print_version(void)

print_warning Forward
void print_warning(const char * msg)

process_repl_command Forward
bool process_repl_command(const char * line, CLIOptions * opt, ModelAPI * api)

qk_norm_forward Forward
void qk_norm_forward(float * q, float * k, const float * q_gamma, const float * k_gamma, int num_heads, int num_kv_heads, int num_tokens, int head_dim, float eps)

Per-head RMSNorm on Q and K.

qkv_index Forward
size_t qkv_index(int h, int t, int d, int num_tokens, int aligned_head_dim)

qkv_projection_parallel Forward
void qkv_projection_parallel(const void * ln1_q8, const void * WQ, const void * WK, const void * WV, float * q_out, float * k_out, float * v_out, int H, int H_kv, int head_dim, int embed_dim, int num_threads)

Parallel Q/K/V projection for single token decode.

qkv_q8_0_dtype_supported Forward
int qkv_q8_0_dtype_supported(CKDataType dt)

qkv_q8_k_dtype_supported Forward
int qkv_q8_k_dtype_supported(CKDataType dt)

quantize_attn_out_head_major_q8_0 Forward
void quantize_attn_out_head_major_q8_0(const float * attn_out, uint8_t * dst, int tokens, int num_heads, int aligned_head_dim)

quantize_batch_q8_0 Forward
void quantize_batch_q8_0(const float * x, void * vy, int num_rows, int k)

Batch quantize FP32 to Q8_0 format (row-major output)

quantize_batch_q8_k Forward
void quantize_batch_q8_k(const float * x, void * vy, int num_rows, int k)

Batch quantize FP32 to Q8_K format (row-major output)

quantize_row_q8_0 Forward
void quantize_row_q8_0(const float * x, void * vy, int k)

Quantize FP32 to Q8_0 format (scalar reference)

quantize_row_q8_k Forward
void quantize_row_q8_k(const float * x, void * vy, int k)

quantize_row_q8_k_ref Forward
void quantize_row_q8_k_ref(const float * x, void * vy, int k)

quantize_row_q8_k_sse Forward
void quantize_row_q8_k_sse(const float * x, void * vy, int k)

qwen2_0_5b_decode_align_elems Forward
int qwen2_0_5b_decode_align_elems(int elems, int elem_bytes, int align_bytes)

qwen2_0_5b_decode_decode Forward
void qwen2_0_5b_decode_decode(QWEN2_0_5B_DECODEModel * model, const int * token, int token_index)

qwen2_0_5b_decode_decode_token Forward
void qwen2_0_5b_decode_decode_token(QWEN2_0_5B_DECODEModel * model, const int * token, int token_index)

qwen2_0_5b_decode_forward Forward
void qwen2_0_5b_decode_forward(QWEN2_0_5B_DECODEModel * model, const int * tokens, int num_tokens)

Forward pass computation

qwen2_0_5b_decode_forward_prefill_impl Forward
void qwen2_0_5b_decode_forward_prefill_impl(QWEN2_0_5B_DECODEModel * model, const int * tokens, int num_tokens)

Forward pass computation

qwen2_0_5b_decode_layer_0_decode Forward
void qwen2_0_5b_decode_layer_0_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_0_prefill Forward
void qwen2_0_5b_decode_layer_0_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_10_decode Forward
void qwen2_0_5b_decode_layer_10_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_10_prefill Forward
void qwen2_0_5b_decode_layer_10_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_11_decode Forward
void qwen2_0_5b_decode_layer_11_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_11_prefill Forward
void qwen2_0_5b_decode_layer_11_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_12_decode Forward
void qwen2_0_5b_decode_layer_12_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_12_prefill Forward
void qwen2_0_5b_decode_layer_12_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_13_decode Forward
void qwen2_0_5b_decode_layer_13_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_13_prefill Forward
void qwen2_0_5b_decode_layer_13_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_14_decode Forward
void qwen2_0_5b_decode_layer_14_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_14_prefill Forward
void qwen2_0_5b_decode_layer_14_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_15_decode Forward
void qwen2_0_5b_decode_layer_15_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_15_prefill Forward
void qwen2_0_5b_decode_layer_15_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_16_decode Forward
void qwen2_0_5b_decode_layer_16_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_16_prefill Forward
void qwen2_0_5b_decode_layer_16_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_17_decode Forward
void qwen2_0_5b_decode_layer_17_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_17_prefill Forward
void qwen2_0_5b_decode_layer_17_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_18_decode Forward
void qwen2_0_5b_decode_layer_18_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_18_prefill Forward
void qwen2_0_5b_decode_layer_18_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_19_decode Forward
void qwen2_0_5b_decode_layer_19_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_19_prefill Forward
void qwen2_0_5b_decode_layer_19_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_1_decode Forward
void qwen2_0_5b_decode_layer_1_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_1_prefill Forward
void qwen2_0_5b_decode_layer_1_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_20_decode Forward
void qwen2_0_5b_decode_layer_20_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_20_prefill Forward
void qwen2_0_5b_decode_layer_20_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_21_decode Forward
void qwen2_0_5b_decode_layer_21_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_21_prefill Forward
void qwen2_0_5b_decode_layer_21_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_22_decode Forward
void qwen2_0_5b_decode_layer_22_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_22_prefill Forward
void qwen2_0_5b_decode_layer_22_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_23_decode Forward
void qwen2_0_5b_decode_layer_23_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_23_prefill Forward
void qwen2_0_5b_decode_layer_23_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_2_decode Forward
void qwen2_0_5b_decode_layer_2_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_2_prefill Forward
void qwen2_0_5b_decode_layer_2_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_3_decode Forward
void qwen2_0_5b_decode_layer_3_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_3_prefill Forward
void qwen2_0_5b_decode_layer_3_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_4_decode Forward
void qwen2_0_5b_decode_layer_4_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_4_prefill Forward
void qwen2_0_5b_decode_layer_4_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_5_decode Forward
void qwen2_0_5b_decode_layer_5_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_5_prefill Forward
void qwen2_0_5b_decode_layer_5_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_6_decode Forward
void qwen2_0_5b_decode_layer_6_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_6_prefill Forward
void qwen2_0_5b_decode_layer_6_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_7_decode Forward
void qwen2_0_5b_decode_layer_7_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_7_prefill Forward
void qwen2_0_5b_decode_layer_7_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_8_decode Forward
void qwen2_0_5b_decode_layer_8_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_8_prefill Forward
void qwen2_0_5b_decode_layer_8_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_9_decode Forward
void qwen2_0_5b_decode_layer_9_decode(QWEN2_0_5B_DECODEModel * model, int token_index, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_layer_9_prefill Forward
void qwen2_0_5b_decode_layer_9_prefill(QWEN2_0_5B_DECODEModel * model, int num_tokens, int aligned_embed_dim, int aligned_head_dim, int aligned_intermediate_dim, int aligned_context_window)

qwen2_0_5b_decode_model_allocate Forward
int qwen2_0_5b_decode_model_allocate(QWEN2_0_5B_DECODEModel * model)

qwen2_0_5b_decode_model_free Forward
void qwen2_0_5b_decode_model_free(QWEN2_0_5B_DECODEModel * model)

qwen2_0_5b_decode_residual_add_token_major Forward
void qwen2_0_5b_decode_residual_add_token_major(const float * a, const float * b, float * out, int tokens, int aligned_embed_dim)

qwen2_0_5b_decode_verify_canaries Forward
int qwen2_0_5b_decode_verify_canaries(QWEN2_0_5B_DECODEModel * model)

read_file_int Forward
int read_file_int(const char * path)

read_file_string Forward
int read_file_string(const char * path, char * buf, size_t buf_size)

read_file_uint64 Forward
uint64_t read_file_uint64(const char * path)

read_floats Forward
int read_floats(FILE * f, float * dst, size_t count)

read_manifest_entry Forward
bool read_manifest_entry(const char * json_path, const char * entry_name, size_t * out_offset, size_t * out_size)

read_prompt_file Forward
char * read_prompt_file(const char * path)

record_allocation Forward
int record_allocation(void * ptr, size_t len, int was_mmap)

relu_backward Backward
void relu_backward(const float * input, const float * d_output, float * d_input, size_t n)

Backward pass / gradient computation

relu_backward_bf16 Backward
void relu_backward_bf16(const uint16_t * input, const uint16_t * d_output, uint16_t * d_input, size_t n)

Backward pass / gradient computation

relu_forward Forward
void relu_forward(const float * input, float * output, size_t n)

Forward pass computation

relu_forward_bf16 Forward
void relu_forward_bf16(const uint16_t * input, uint16_t * output, size_t n)

Forward pass computation

relu_forward_inplace Forward
void relu_forward_inplace(float * data, size_t n)

Forward pass computation

relu_forward_inplace_bf16 Forward
void relu_forward_inplace_bf16(uint16_t * data, size_t n)

Forward pass computation

residual_add Forward
void residual_add(float * residual, float * addend, int n)

residual_add_parallel Forward
void residual_add_parallel(const float * a, const float * b, float * out, int n, int ith, int nth)

Single-token decode with parallel SIMD kernels.

resolve_dim Forward
size_t resolve_dim(const CKModelConfig * cfg, const CKIRV2AlignInfo * align, CKDimKind kind, int tokens_override)

resolve_shape_elems Forward
size_t resolve_shape_elems(const CKModelConfig * cfg, const CKIRV2AlignInfo * align, const CKDimToken * shape, int tokens_override)

resolve_symbol Forward
bool resolve_symbol(void * handle, const char * name, void ** out_ptr, bool required)

run_benchmark Forward
void run_benchmark(void * model, int num_tokens)

run_command Forward
int run_command(const char * cmd, char * output, size_t output_size)

run_generation_test Forward
void run_generation_test(void * model, int num_tokens)

run_inference Forward
int run_inference(const char * bump_path, const char * manifest_path, const char * tokenizer_path, const char * prompt, int max_tokens, float temperature, int topk)

run_prompt Forward
int run_prompt(ModelAPI * api, CKTrueBPE * tokenizer, CLIOptions * opt, const char * input)

sample_argmax Forward
int sample_argmax(const float * logits, int vocab_size)

sample_token Forward
int sample_token(float * logits, int vocab_size, float temp, int top_k)

sample_top_p Forward
int sample_top_p(float * logits, int vocab_size, float temperature, float top_p)

sample_topk Forward
int sample_topk(float * probs, int vocab_size, int topk)

scal_copy_f32 Forward
void scal_copy_f32(float * y, const float * x, float alpha, int n)

Scaled copy: y = alpha * x.

score_index Forward
size_t score_index(int h, int i, int j, int aligned_context_window)

sgd_momentum_update_bf16 Forward
void sgd_momentum_update_bf16(const uint16_t * grad, uint16_t * weight, float * velocity, size_t numel, float lr, float momentum, float weight_decay)

SGD with momentum optimizer update (bf16 weights/gradients, fp32 velocity)

sgd_momentum_update_f32 Forward
void sgd_momentum_update_f32(const float * grad, float * weight, float * velocity, size_t numel, float lr, float momentum, float weight_decay)

SGD with momentum optimizer update (fp32 weights/gradients, fp32 velocity)

silu Forward
void silu(float * x, int n)

silu_prefill Forward
float silu_prefill(float x)

silu_scalar Forward
float silu_scalar(float x)

simd_strcmp Forward
int simd_strcmp(const char * s1, const char * s2)

simple_embedding Forward
void simple_embedding(const int32_t * tokens, int num_tokens, const float * weight, float * output, int vocab_size, int embed_dim)

spm_build_byte_lookup Forward
void spm_build_byte_lookup(CKTokenizer * tok, const char * strings, const int32_t * offsets, int vocab_size)

spm_count_unknown_run Forward
int spm_count_unknown_run(const CKTokenizer * tok, const char * text, int text_len, size_t pos)

spm_encode_byte_fallback Forward
int spm_encode_byte_fallback(const CKTokenizer * tok, const char * text, int text_len, int32_t * ids, int max_ids)

spm_find_candidates_at_pos Forward
int spm_find_candidates_at_pos(const CKTokenizer * tok, const char * text, int text_len, size_t pos, int32_t * candidates, int max_candidates)

spm_get_byte_token Forward
int32_t spm_get_byte_token(const CKTokenizer * tok, unsigned char byte_val)

spm_is_byte_token Forward
bool spm_is_byte_token(const CKTokenizer * tok, int32_t token_id)

spm_llama_resegment_node Forward
int spm_llama_resegment_node(const CKTokenizer * tok, const SpmLlamaNode * nodes, int node_id, int32_t * ids, int max_ids, int out_idx)

spm_token_allowed_in_dp Forward
bool spm_token_allowed_in_dp(const CKTokenizer * tok, int32_t token_id)

spm_token_is_byte_format Forward
bool spm_token_is_byte_format(const char * token)

starts_with Forward
int starts_with(const char * s, const char * prefix)

token_list_append Forward
int token_list_append(CKBPETokenList * list, const char * str, size_t len, int32_t id)

token_list_clear Forward
void token_list_clear(CKBPETokenList * list)

token_list_create Forward
CKBPETokenList * token_list_create(size_t initial_capacity)

token_list_free Forward
void token_list_free(CKBPETokenList * list)

token_list_merge_at Forward
int token_list_merge_at(CKBPETokenList * list, size_t pos, const char * merged_str, size_t merged_len, int32_t merged_id)

tokenize Forward
int32_t * tokenize(const char * text, int * num_tokens)

topk_batched_f32 Forward
void topk_batched_f32(const float * scores, int num_tokens, int n_experts, int k, int * indices, float * weights)

Batched top-K selection for multiple tokens.

topk_f32 Forward
void topk_f32(const float * scores, int n, int k, int * indices, float * values)

Find top-K indices and values from a score vector.

topology_discover Forward
int topology_discover(SystemTopology * topo)

topology_discover_affinity Forward
int topology_discover_affinity(AffinityInfo * aff)

topology_discover_cache Forward
int topology_discover_cache(CacheTopology * cache)

topology_discover_cpu Forward
int topology_discover_cpu(CPUInfo * cpu)

topology_discover_memory Forward
int topology_discover_memory(MemoryInfo * mem)

topology_discover_network Forward
int topology_discover_network(NetworkTopology * net)

topology_discover_numa Forward
int topology_discover_numa(NUMATopology * numa)

topology_discover_pcie Forward
int topology_discover_pcie(PCIeTopology * pcie)

topology_estimate_channels_from_bandwidth Forward
int topology_estimate_channels_from_bandwidth(float measured_bw_gbs, int memory_speed_mhz, const char * memory_type)

topology_estimate_memory_bandwidth Forward
float topology_estimate_memory_bandwidth(const MemoryInfo * mem)

topology_estimate_network_training_time Forward
float topology_estimate_network_training_time(const NetworkTopology * net, uint64_t model_size_mb)

topology_generate_recommendations Forward
int topology_generate_recommendations(const SystemTopology * topo, RecommendationList * recs)

topology_measure_memory_bandwidth Forward
float topology_measure_memory_bandwidth(void)

topology_measure_memory_bandwidth_ex Forward
float topology_measure_memory_bandwidth_ex(int * numa_node_out, int * num_threads_out)

topology_print_affinity Forward
void topology_print_affinity(const AffinityInfo * aff)

topology_print_cache Forward
void topology_print_cache(const CacheTopology * cache, int logical_cores)

topology_print_cpu Forward
void topology_print_cpu(const CPUInfo * cpu)

topology_print_distributed_potential Forward
void topology_print_distributed_potential(const SystemTopology * topo)

topology_print_memory Forward
void topology_print_memory(const MemoryInfo * mem)

topology_print_network Forward
void topology_print_network(const NetworkTopology * net)

topology_print_numa Forward
void topology_print_numa(const NUMATopology * numa, int sockets)

topology_print_pcie Forward
void topology_print_pcie(const PCIeTopology * pcie)

topology_print_recommendations Forward
void topology_print_recommendations(const RecommendationList * recs)

topology_print_summary Forward
void topology_print_summary(const SystemTopology * topo)

trim_string Forward
void trim_string(char * str)

unpack_q4_k_scales Forward
void unpack_q4_k_scales(const uint8_t * scales, uint8_t * sc, uint8_t * m)

Unpack Q4_K sub-block scales and mins.

unpack_q5_k_scales Forward
void unpack_q5_k_scales(const uint8_t * scales, uint8_t * sc, uint8_t * m)

Unpack Q5_K sub-block scales and mins.

utf8_char_len Forward
int utf8_char_len(unsigned char c)

utf8_len Forward
int utf8_len(unsigned char c)

v6_prefill Forward
void v6_prefill(const float * embed_weight, const int32_t * tokens, int num_tokens, float * logits)

vec_dot_q5_0_q8_0 Forward
void vec_dot_q5_0_q8_0(int n, float * s, const void * vx, const void * vy)

Auto-dispatch quantized dot product Q5_0 x Q8_0.

vec_dot_q5_0_q8_0_ref Forward
void vec_dot_q5_0_q8_0_ref(int n, float * s, const void * vx, const void * vy)

Quantized dot product: Q5_0 weights x Q8_0 input (scalar reference)

vec_dot_q6_k_q8_k Forward
void vec_dot_q6_k_q8_k(int n, float * s, const void * vx, const void * vy)

Q6_K x Q8_K dot product (single row)

vec_dot_q8_0_q8_0 Forward
void vec_dot_q8_0_q8_0(int n, float * s, const void * vx, const void * vy)

Auto-dispatch quantized dot product Q8_0 x Q8_0.

vec_dot_q8_0_q8_0_ref Forward
void vec_dot_q8_0_q8_0_ref(int n, float * s, const void * vx, const void * vy)

Quantized dot product: Q8_0 weights x Q8_0 input (scalar reference)

vec_scale_parallel Forward
void vec_scale_parallel(float * y, float scale, int n, int ith, int nth)

vec_zero_parallel Forward
void vec_zero_parallel(float * y, int n, int ith, int nth)

weighted_sum_f32 Forward
void weighted_sum_f32(float * y, const float ** vectors, const float * weights, int k, int n)

Weighted sum of k vectors: y = sum_i(weights[i] * vectors[i])

worker_main Forward
void * worker_main(void * arg)

zero_gradients_bf16 Forward
void zero_gradients_bf16(uint16_t * grad, size_t numel)

Zero out gradient buffer (bf16)

zero_gradients_f32 Forward
void zero_gradients_f32(float * grad, size_t numel)

Zero out gradient buffer (fp32)

Usage Example

Forward + Backward Pass

// Forward pass
rmsnorm_forward(input, gamma, norm_out, rstd_cache, tokens, d_model, d_model, eps);
attention_forward_causal_head_major_gqa(q, k, v, scores, attn_out, heads, kv_heads, tokens, head_dim, head_dim, ctx_len);
swiglu_forward(mlp_in, mlp_out, tokens, hidden_dim);

// Backward pass (reverse order)
swiglu_backward(mlp_in, d_mlp_out, d_mlp_in, tokens, hidden_dim);
attention_backward_causal_head_major_gqa(d_attn_out, q, k, v, scores, d_q, d_k, d_v, d_scores, heads, kv_heads, tokens, head_dim, head_dim, ctx_len);
rmsnorm_backward(d_norm_out, input, gamma, rstd_cache, d_input, d_gamma, tokens, d_model, d_model);
Image
100% | |
Scroll to zoom | Drag to pan | W/H to fit | 0 to reset | ESC to close