C-Kernel-Engine Parity Testing API Implementation. More...
#include "ck_parity_api.h"
#include "ckernel_quant.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>
Go to the source code of this file.
Functions | |
| void | attention_forward_causal_head_major_gqa_flash_strided (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens) |
| void | attention_forward_causal_head_major_gqa_flash_strided_sliding (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window) |
| void | attention_forward_decode_head_major_gqa_flash_sliding (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window) |
| int | ck_get_block_q4_k_size (void) |
| Get Q4_K block size in bytes. More... | |
| int | ck_get_block_q5_1_size (void) |
| Get Q5_1 block size in bytes (24 bytes per 32 weights) More... | |
| int | ck_get_block_q5_k_size (void) |
| Get Q5_K block size in bytes (176 bytes per 256 weights) More... | |
| int | ck_get_block_q6_k_size (void) |
| Get Q6_K block size in bytes. More... | |
| int | ck_get_block_q8_k_size (void) |
| Get Q8_K block size in bytes. More... | |
| int | ck_get_qk5_1 (void) |
| Get QK5_1 (elements per Q5_1 block) More... | |
| int | ck_get_qk_k (void) |
| Get QK_K (elements per super-block) More... | |
| void | ck_test_attention_causal (const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim) |
| Multi-head causal attention for prefill (head-major layout) More... | |
| void | ck_test_attention_decode_sliding (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int sliding_window) |
| Test sliding-window attention (decode mode) More... | |
| void | ck_test_attention_sliding_window (const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim, int sliding_window) |
| Test sliding-window attention (prefill) More... | |
| void | ck_test_dequant_q4_0 (const void *src, float *dst, int n) |
| Dequantize Q4_0 data to FP32. More... | |
| void | ck_test_dequant_q4_k (const void *src, float *dst, int n) |
| Dequantize Q4_K data to FP32. More... | |
| void | ck_test_dequant_q5_1 (const void *src, float *dst, int n) |
| Dequantize Q5_1 data to FP32. More... | |
| void | ck_test_dequant_q6_k (const void *src, float *dst, int n) |
| Dequantize Q6_K data to FP32. More... | |
| void | ck_test_geglu (const float *x, float *out, int n_tokens, int dim) |
| Test GeGLU activation. More... | |
| void | ck_test_geglu_backward (const float *x, const float *d_out, float *d_x, int n_tokens, int dim) |
| Test GeGLU backward. More... | |
| void | ck_test_gemm_q4_k (const void *weight_q4k, const float *input_f32, float *output, int rows, int cols, int n_tokens) |
| Q4_K GEMM - batched matrix multiply with quantized weights. More... | |
| void | ck_test_gemm_q5_0 (const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols, int n_tokens) |
| Test Q5_0 x Q8_0 GEMM (batch matrix multiply) More... | |
| void | ck_test_gemm_q5_1 (const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols, int n_tokens) |
| Test Q5_1 x Q8_0 GEMM (batch matrix multiply) More... | |
| void | ck_test_gemm_q5_k (const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols, int n_tokens) |
| Test Q5_K x Q8_K GEMM (batch matrix multiply) More... | |
| void | ck_test_gemm_q6_k (const void *weight_q6k, const float *input_f32, float *output, int rows, int cols, int n_tokens) |
| Test Q6_K x Q8_K GEMM (batch matrix multiply) More... | |
| void | ck_test_gemm_q8_0 (const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols, int n_tokens) |
| Test Q8_0 x Q8_0 GEMM (batch matrix multiply) More... | |
| void | ck_test_gemv_q4_k (const void *weight_q4k, const float *input_f32, float *output, int cols) |
| Q4_K GEMV - dot product of quantized weights and FP32 input. More... | |
| void | ck_test_gemv_q5_0 (const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols) |
| Q5_0 GEMV - matrix-vector multiply with Q5_0 weights. More... | |
| void | ck_test_gemv_q5_0_q8_0 (const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols) |
| Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach. More... | |
| void | ck_test_gemv_q5_1 (const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols) |
| Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks) More... | |
| void | ck_test_gemv_q5_k (const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols) |
| Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks) More... | |
| void | ck_test_gemv_q6_k (const void *weight_q6k, const float *input_f32, float *output, int cols) |
| Q6_K GEMV. More... | |
| void | ck_test_gemv_q8_0 (const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols) |
| Q8_0 GEMV - matrix-vector multiply with Q8_0 weights. More... | |
| void | ck_test_gemv_q8_0_q8_0 (const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols) |
| Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach. More... | |
| void | ck_test_outproj_mlp_fused_q5_0 (const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const void *w1, const void *w2, float *output, int tokens, int num_heads, int head_dim, int embed_dim, int intermediate, float eps, int w2_is_q6k) |
| Test mega-fused OutProj + MLP kernel (Q5_0 weights) More... | |
| void | ck_test_quantize_q8_k (const float *src, void *dst, int n) |
| Quantize FP32 to Q8_K (for activations) More... | |
| void | ck_test_rmsnorm (const float *input, const float *weight, float *output, int n_tokens, int dim, float eps) |
| RMSNorm. More... | |
| void | ck_test_rope (float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta) |
| RoPE (Rotary Position Embedding) More... | |
| void | ck_test_rope_interleaved (float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta) |
| RoPE with interleaved format (for llama.cpp compatibility) More... | |
| void | ck_test_softmax (const float *input, float *output, int n) |
| Softmax (simple, non-causal) More... | |
| void | ck_test_swiglu (const float *gate_up, float *output, int n_tokens, int intermediate_dim) |
| SwiGLU activation. More... | |
| void | ck_test_vec_dot_q5_0_q8_0 (const void *weight_q5_0, const void *input_q8_0, float *output, int cols) |
| Direct Q5_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input) More... | |
| void | ck_test_vec_dot_q8_0_q8_0 (const void *weight_q8_0, const void *input_q8_0, float *output, int cols) |
| Direct Q8_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input) More... | |
| void | dequant_q4_0_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q4_0 row (multiple blocks) More... | |
| void | dequant_q4_k_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q4_K row (multiple blocks) More... | |
| void | dequant_q5_1_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q5_1 row (multiple blocks) More... | |
| void | dequant_q6_k_row (const void *src, float *dst, size_t n_elements) |
| Dequantize Q6_K row (multiple blocks) More... | |
| void | geglu_backward_fp32 (const float *x, const float *d_out, float *d_x, int n_tokens, int dim) |
| void | geglu_forward_fp32 (const float *x, float *out, int tokens, int dim) |
| void | gemm_nt_q4_k_q8_k (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nt_q5_0_q8_0 (const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K) |
| Batch GEMM with Q5_0 weights and Q8_0 activations for prefill. More... | |
| void | gemm_nt_q5_1 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| GEMM with transposed Q5_1 weights: C = A @ B^T. More... | |
| void | gemm_nt_q5_k (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nt_q6_k_q8_k (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K) |
| NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K. More... | |
| void | gemm_nt_q8_0_q8_0 (const void *A_q8, const void *B_q8, const float *bias, float *C, int M, int N, int K) |
| gemm_nt_q8_0_q8_0 with optional bias (matches header signature) More... | |
| void | gemv_q4_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K) |
| void | gemv_q5_0 (float *y, const void *W, const float *x, int M, int K) |
| Auto-dispatch GEMV for Q5_0 weights based on CPU features. More... | |
| void | gemv_q5_0_q8_0 (float *y, const void *W, const void *x_q8, int M, int K) |
| Matrix-vector multiply with Q5_0 weights and Q8_0 input. More... | |
| void | gemv_q5_1 (float *y, const void *W, const float *x, int M, int K) |
| Auto-dispatch GEMV. More... | |
| void | gemv_q5_k (float *y, const void *W, const float *x, int M, int K) |
| void | gemv_q6_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K) |
| GEMV: y = W @ x where W is Q6_K and x is Q8_K. More... | |
| void | gemv_q8_0 (float *y, const void *W, const float *x, int M, int K) |
| Auto-dispatch GEMV for Q8_0 weights based on CPU features. More... | |
| void | gemv_q8_0_q8_0 (float *y, const void *W, const void *x_q8, int M, int K) |
| Matrix-vector multiply with Q8_0 weights and Q8_0 input. More... | |
| void | mega_fused_outproj_mlp_prefill (float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, int wo_dt, const void *w1, const float *b1, int w1_dt, const void *w2, const float *b2, int w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch) |
| size_t | mega_fused_outproj_mlp_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim) |
| Get scratch buffer size for mega_fused_outproj_mlp_prefill. More... | |
| void | quantize_row_q8_0 (const float *x, void *vy, int k) |
| Quantize FP32 to Q8_0 format (scalar reference) More... | |
| void | quantize_row_q8_k (const float *x, void *vy, int k) |
| void | rmsnorm_forward (const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps) |
| void | rope_forward_qk (float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset) |
| void | rope_precompute_cache (float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base) |
| void | swiglu_forward (const float *input, float *output, int tokens, int dim) |
| void | vec_dot_q5_0_q8_0 (int n, float *s, const void *vx, const void *vy) |
| Auto-dispatch quantized dot product Q5_0 x Q8_0. More... | |
| void | vec_dot_q8_0_q8_0 (int n, float *s, const void *vx, const void *vy) |
| Auto-dispatch quantized dot product Q8_0 x Q8_0. More... | |
C-Kernel-Engine Parity Testing API Implementation.
Wraps CK kernels for parity testing against llama.cpp/ggml.
Definition in file ck_parity_api.c.
| void attention_forward_causal_head_major_gqa_flash_strided | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens | ||
| ) |
Flash attention forward with custom KV stride (for KV cache)
test_flash_attention.py::TestFlashAttention::test_flash_strided
test_kv_cache_attention.py::TestKVCacheAttention::test_flash_attention
Variant with configurable kv_stride_tokens for KV cache layouts where K/V may not be contiguous in memory.
After changes: make test
Definition at line 859 of file attention_kernels.c.
Referenced by ck_test_attention_causal().
| void attention_forward_causal_head_major_gqa_flash_strided_sliding | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | output, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | kv_stride_tokens, | ||
| int | sliding_window | ||
| ) |
Flash attention forward with sliding window (prefill)
Sliding-window attention for prefill: each token attends to the last W tokens. When sliding_window <= 0, behaves like regular causal attention.
After changes: make test
Definition at line 1316 of file attention_kernels.c.
Referenced by ck_test_attention_sliding_window().
| void attention_forward_decode_head_major_gqa_flash_sliding | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | sliding_window | ||
| ) |
Flash attention decode with sliding window
Single query token attends to the last W tokens in the KV cache. For decode: effective_kv_tokens = min(kv_tokens, sliding_window)
After changes: make test
Definition at line 1382 of file attention_kernels.c.
Referenced by ck_test_attention_decode_sliding().
| int ck_get_block_q4_k_size | ( | void | ) |
Get Q4_K block size in bytes.
Definition at line 961 of file ck_parity_api.c.
| int ck_get_block_q5_1_size | ( | void | ) |
Get Q5_1 block size in bytes (24 bytes per 32 weights)
Definition at line 986 of file ck_parity_api.c.
| int ck_get_block_q5_k_size | ( | void | ) |
Get Q5_K block size in bytes (176 bytes per 256 weights)
Definition at line 981 of file ck_parity_api.c.
| int ck_get_block_q6_k_size | ( | void | ) |
Get Q6_K block size in bytes.
Definition at line 966 of file ck_parity_api.c.
| int ck_get_block_q8_k_size | ( | void | ) |
Get Q8_K block size in bytes.
Definition at line 971 of file ck_parity_api.c.
| int ck_get_qk5_1 | ( | void | ) |
Get QK5_1 (elements per Q5_1 block)
Definition at line 991 of file ck_parity_api.c.
References QK5_1.
| int ck_get_qk_k | ( | void | ) |
Get QK_K (elements per super-block)
Definition at line 976 of file ck_parity_api.c.
References QK_K.
| void ck_test_attention_causal | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | out, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | tokens, | ||
| int | seq_len, | ||
| int | head_dim | ||
| ) |
Multi-head causal attention for prefill (head-major layout)
Layout (head-major, matches llama.cpp test): Q: [num_heads, tokens, head_dim] K: [num_kv_heads, seq_len, head_dim] V: [num_kv_heads, seq_len, head_dim] out: [num_heads, tokens, head_dim]
Supports GQA (grouped-query attention) where num_heads > num_kv_heads. Causal masking: token t can only attend to positions 0..t (inclusive).
| q | Query [num_heads, tokens, head_dim] |
| k | Key [num_kv_heads, seq_len, head_dim] |
| v | Value [num_kv_heads, seq_len, head_dim] |
| out | Output [num_heads, tokens, head_dim] |
| num_heads | Number of query heads |
| num_kv_heads | Number of key/value heads (for GQA) |
| tokens | Number of query tokens |
| seq_len | Key/value sequence length (for prefill: seq_len == tokens) |
| head_dim | Dimension per head |
Definition at line 736 of file ck_parity_api.c.
References attention_forward_causal_head_major_gqa_flash_strided().
| void ck_test_attention_decode_sliding | ( | const float * | q_token, |
| const float * | k_cache, | ||
| const float * | v_cache, | ||
| float * | out_token, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | kv_tokens, | ||
| int | cache_capacity, | ||
| int | head_dim, | ||
| int | sliding_window | ||
| ) |
Test sliding-window attention (decode mode)
Single query token attending to KV cache with sliding window.
Definition at line 794 of file ck_parity_api.c.
References attention_forward_decode_head_major_gqa_flash_sliding().
| void ck_test_attention_sliding_window | ( | const float * | q, |
| const float * | k, | ||
| const float * | v, | ||
| float * | out, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | tokens, | ||
| int | seq_len, | ||
| int | head_dim, | ||
| int | sliding_window | ||
| ) |
Test sliding-window attention (prefill)
Layout (head-major, matching CK-Engine): Q: [num_heads, tokens, head_dim] K: [num_kv_heads, seq_len, head_dim] V: [num_kv_heads, seq_len, head_dim] out: [num_heads, tokens, head_dim]
Each token attends only to the last sliding_window tokens.
Definition at line 769 of file ck_parity_api.c.
References attention_forward_causal_head_major_gqa_flash_strided_sliding().
| void ck_test_dequant_q4_0 | ( | const void * | src, |
| float * | dst, | ||
| int | n | ||
| ) |
Dequantize Q4_0 data to FP32.
Definition at line 122 of file ck_parity_api.c.
References dequant_q4_0_row().
| void ck_test_dequant_q4_k | ( | const void * | src, |
| float * | dst, | ||
| int | n | ||
| ) |
Dequantize Q4_K data to FP32.
| src | Input Q4_K blocks |
| dst | Output FP32 values |
| n | Number of elements (must be multiple of 256) |
Definition at line 112 of file ck_parity_api.c.
References dequant_q4_k_row().
| void ck_test_dequant_q5_1 | ( | const void * | src, |
| float * | dst, | ||
| int | n | ||
| ) |
Dequantize Q5_1 data to FP32.
Definition at line 127 of file ck_parity_api.c.
References dequant_q5_1_row().
| void ck_test_dequant_q6_k | ( | const void * | src, |
| float * | dst, | ||
| int | n | ||
| ) |
Dequantize Q6_K data to FP32.
Definition at line 117 of file ck_parity_api.c.
References dequant_q6_k_row().
| void ck_test_geglu | ( | const float * | x, |
| float * | out, | ||
| int | n_tokens, | ||
| int | dim | ||
| ) |
Test GeGLU activation.
Computes: output = GELU(a) * b where input contains [a, b] concatenated along the last dimension.
Definition at line 819 of file ck_parity_api.c.
References geglu_forward_fp32().
| void ck_test_geglu_backward | ( | const float * | x, |
| const float * | d_out, | ||
| float * | d_x, | ||
| int | n_tokens, | ||
| int | dim | ||
| ) |
Test GeGLU backward.
Computes gradients dL/dx given dL/d(out) where out = GELU(a) * b
Definition at line 832 of file ck_parity_api.c.
References geglu_backward_fp32().
| void ck_test_gemm_q4_k | ( | const void * | weight_q4k, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols, | ||
| int | n_tokens | ||
| ) |
Q4_K GEMM - batched matrix multiply with quantized weights.
Computes: output[t,r] = sum_k(weight[r,k] * input[t,k])
| weight_q4k | Q4_K quantized weights [rows, cols] |
| input_f32 | FP32 input [n_tokens, cols] |
| output | FP32 output [n_tokens, rows] |
| rows | Number of output rows |
| cols | Number of columns (must be multiple of 256) |
| n_tokens | Batch size |
Definition at line 392 of file ck_parity_api.c.
References CK_QK_K, gemm_nt_q4_k_q8_k(), and quantize_row_q8_k().
| void ck_test_gemm_q5_0 | ( | const void * | weight_q5_0, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols, | ||
| int | n_tokens | ||
| ) |
Test Q5_0 x Q8_0 GEMM (batch matrix multiply)
Q5_0 GEMM - batched matrix multiply with Q5_0 weights (32-element blocks)
Used for MLP W1 (gate/up projection) and attention Q/K with Q5_0 weights.
Definition at line 491 of file ck_parity_api.c.
References CK_QK8_0, gemm_nt_q5_0_q8_0(), and quantize_row_q8_0().
| void ck_test_gemm_q5_1 | ( | const void * | weight_q5_1, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols, | ||
| int | n_tokens | ||
| ) |
Test Q5_1 x Q8_0 GEMM (batch matrix multiply)
Q5_1 GEMM - batched matrix multiply with Q5_1 weights (32-element blocks)
Used for MLP W1 (gate/up projection) and attention Q/K with Q5_1 weights. gemm_nt_q5_1 expects FP32 activations (not quantized).
Definition at line 542 of file ck_parity_api.c.
References gemm_nt_q5_1().
| void ck_test_gemm_q5_k | ( | const void * | weight_q5_k, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols, | ||
| int | n_tokens | ||
| ) |
Test Q5_K x Q8_K GEMM (batch matrix multiply)
Q5_K GEMM - batched matrix multiply with Q5_K weights (256-element super-blocks)
Used for MLP W1 (gate/up projection) and attention Q/K with Q5_K weights. gemm_nt_q5_k expects FP32 activations (not quantized).
Definition at line 525 of file ck_parity_api.c.
References gemm_nt_q5_k().
| void ck_test_gemm_q6_k | ( | const void * | weight_q6k, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols, | ||
| int | n_tokens | ||
| ) |
Test Q6_K x Q8_K GEMM (batch matrix multiply)
Q6_K GEMM - batched matrix multiply with Q6_K weights.
Used for MLP W2 (down projection) with Q6_K weights.
Definition at line 425 of file ck_parity_api.c.
References CK_QK_K, gemm_nt_q6_k_q8_k(), and quantize_row_q8_k().
| void ck_test_gemm_q8_0 | ( | const void * | weight_q8_0, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols, | ||
| int | n_tokens | ||
| ) |
Test Q8_0 x Q8_0 GEMM (batch matrix multiply)
Q8_0 GEMM - batched matrix multiply with Q8_0 weights (32-element blocks)
Used for attention V projection with Q8_0 weights.
Definition at line 458 of file ck_parity_api.c.
References CK_QK8_0, gemm_nt_q8_0_q8_0(), and quantize_row_q8_0().
| void ck_test_gemv_q4_k | ( | const void * | weight_q4k, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | cols | ||
| ) |
Q4_K GEMV - dot product of quantized weights and FP32 input.
Internally quantizes input to Q8_K, then computes dot product.
| weight_q4k | Q4_K quantized weights [cols] |
| input_f32 | FP32 input vector [cols] |
| output | Output scalar [1] |
| cols | Number of columns (must be multiple of 256) |
Definition at line 145 of file ck_parity_api.c.
References CK_QK_K, gemv_q4_k_q8_k(), and quantize_row_q8_k().
| void ck_test_gemv_q5_0 | ( | const void * | weight_q5_0, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols | ||
| ) |
Q5_0 GEMV - matrix-vector multiply with Q5_0 weights.
| weight_q5_0 | Q5_0 quantized weights [rows * cols] |
| input_f32 | FP32 input vector [cols] |
| output | FP32 output vector [rows] |
| rows | Number of output rows |
| cols | Number of columns (must be multiple of 32) |
Definition at line 192 of file ck_parity_api.c.
References CK_QK8_0, gemv_q5_0_q8_0(), and quantize_row_q8_0().
| void ck_test_gemv_q5_0_q8_0 | ( | const void * | weight_q5_0, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols | ||
| ) |
Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
This version quantizes the input to Q8_0 first, then uses integer dot products (like llama.cpp does). Use this for parity testing.
| weight_q5_0 | Q5_0 quantized weights [rows * cols] |
| input_f32 | FP32 input vector [cols] - will be quantized to Q8_0 |
| output | FP32 output vector [rows] |
| rows | Number of output rows |
| cols | Number of columns (must be multiple of 32) |
Definition at line 248 of file ck_parity_api.c.
References CK_QK8_0, gemv_q5_0_q8_0(), and quantize_row_q8_0().
| void ck_test_gemv_q5_1 | ( | const void * | weight_q5_1, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols | ||
| ) |
Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks)
Uses Q8_0 for activations (like Q5_0).
| weight_q5_1 | Q5_1 quantized weights [rows * cols] |
| input_f32 | FP32 input vector [cols] |
| output | FP32 output vector [rows] |
| rows | Number of output rows |
| cols | Number of columns (must be multiple of 32) |
Definition at line 333 of file ck_parity_api.c.
References gemv_q5_1(), and QK5_1.
| void ck_test_gemv_q5_k | ( | const void * | weight_q5_k, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols | ||
| ) |
Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks)
Uses Q8_K for activations (like Q4_K).
| weight_q5_k | Q5_K quantized weights [rows * cols] |
| input_f32 | FP32 input vector [cols] |
| output | FP32 output vector [rows] |
| rows | Number of output rows |
| cols | Number of columns (must be multiple of 256) |
Definition at line 301 of file ck_parity_api.c.
References CK_QK_K, and gemv_q5_k().
| void ck_test_gemv_q6_k | ( | const void * | weight_q6k, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | cols | ||
| ) |
Q6_K GEMV.
| void ck_test_gemv_q8_0 | ( | const void * | weight_q8_0, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols | ||
| ) |
Q8_0 GEMV - matrix-vector multiply with Q8_0 weights.
| weight_q8_0 | Q8_0 quantized weights [rows * cols] |
| input_f32 | FP32 input vector [cols] |
| output | FP32 output vector [rows] |
| rows | Number of output rows |
| cols | Number of columns (must be multiple of 32) |
Definition at line 220 of file ck_parity_api.c.
References CK_QK8_0, gemv_q8_0_q8_0(), and quantize_row_q8_0().
| void ck_test_gemv_q8_0_q8_0 | ( | const void * | weight_q8_0, |
| const float * | input_f32, | ||
| float * | output, | ||
| int | rows, | ||
| int | cols | ||
| ) |
Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.
This version quantizes the input to Q8_0 first, then uses integer dot products (like llama.cpp does). Use this for parity testing.
| weight_q8_0 | Q8_0 quantized weights [rows * cols] |
| input_f32 | FP32 input vector [cols] - will be quantized to Q8_0 |
| output | FP32 output vector [rows] |
| rows | Number of output rows |
| cols | Number of columns (must be multiple of 32) |
Definition at line 274 of file ck_parity_api.c.
References CK_QK8_0, gemv_q8_0_q8_0(), and quantize_row_q8_0().
| void ck_test_outproj_mlp_fused_q5_0 | ( | const float * | attn_out, |
| const float * | residual, | ||
| const float * | ln2_gamma, | ||
| const void * | wo, | ||
| const void * | w1, | ||
| const void * | w2, | ||
| float * | output, | ||
| int | tokens, | ||
| int | num_heads, | ||
| int | head_dim, | ||
| int | embed_dim, | ||
| int | intermediate, | ||
| float | eps, | ||
| int | w2_is_q6k | ||
| ) |
Test mega-fused OutProj + MLP kernel (Q5_0 weights)
This is a simplified wrapper for parity testing: it computes the required scratch size via mega_fused_outproj_mlp_prefill_scratch_size(), allocates the scratch buffer internally, and dispatches to mega_fused_outproj_mlp_prefill().
| attn_out | Attention output [num_heads, tokens, head_dim] (FP32, head-major) |
| residual | Residual input [tokens, embed_dim] (FP32) |
| ln2_gamma | RMSNorm gamma [embed_dim] (FP32) |
| wo | OutProj weights [embed_dim, embed_dim] (Q5_0) |
| w1 | MLP W1 weights [2*intermediate, embed_dim] (Q5_0) |
| w2 | MLP W2 weights [embed_dim, intermediate] (Q4_K or Q6_K) |
| output | Output [tokens, embed_dim] (FP32) |
| tokens | Number of tokens |
| num_heads | Number of attention heads |
| head_dim | Dimension per head |
| embed_dim | Embedding dimension (= num_heads * head_dim) |
| intermediate | MLP intermediate dimension |
| eps | RMSNorm epsilon |
| w2_is_q6k | If true, W2 is Q6_K; if false, W2 is Q4_K |
Definition at line 894 of file ck_parity_api.c.
References mega_fused_outproj_mlp_prefill(), and mega_fused_outproj_mlp_prefill_scratch_size().
| void ck_test_quantize_q8_k | ( | const float * | src, |
| void * | dst, | ||
| int | n | ||
| ) |
Quantize FP32 to Q8_K (for activations)
| src | Input FP32 values |
| dst | Output Q8_K blocks |
| n | Number of elements (must be multiple of 256) |
Definition at line 136 of file ck_parity_api.c.
References quantize_row_q8_k().
| void ck_test_rmsnorm | ( | const float * | input, |
| const float * | weight, | ||
| float * | output, | ||
| int | n_tokens, | ||
| int | dim, | ||
| float | eps | ||
| ) |
RMSNorm.
Computes: output = (input / rms(input)) * weight where rms(x) = sqrt(mean(x^2) + eps)
| input | Input tensor [n_tokens, dim] |
| weight | Normalization weights [dim] |
| output | Output tensor [n_tokens, dim] |
| n_tokens | Number of tokens |
| dim | Hidden dimension |
| eps | Epsilon for numerical stability |
Definition at line 557 of file ck_parity_api.c.
References rmsnorm_forward().
| void ck_test_rope | ( | float * | q, |
| float * | k, | ||
| int | n_tokens, | ||
| int | n_heads, | ||
| int | n_heads_kv, | ||
| int | head_dim, | ||
| int | pos_offset, | ||
| float | theta | ||
| ) |
RoPE (Rotary Position Embedding)
Applies rotary position embeddings to Q and K tensors.
NOTE: CK uses rotate-half format (split first/second halves) while some implementations use interleaved format. The test harness should account for this.
| q | Query tensor [n_tokens, n_heads * head_dim], modified in-place |
| k | Key tensor [n_tokens, n_heads_kv * head_dim], modified in-place |
| n_tokens | Number of tokens |
| n_heads | Number of query heads |
| n_heads_kv | Number of key/value heads |
| head_dim | Dimension per head |
| pos_offset | Starting position for RoPE |
| theta | RoPE base frequency (typically 10000.0) |
Definition at line 567 of file ck_parity_api.c.
References rope_forward_qk(), and rope_precompute_cache().
| void ck_test_rope_interleaved | ( | float * | q, |
| float * | k, | ||
| int | n_tokens, | ||
| int | n_heads, | ||
| int | n_heads_kv, | ||
| int | head_dim, | ||
| int | pos_offset, | ||
| float | theta | ||
| ) |
RoPE with interleaved format (for llama.cpp compatibility)
Uses interleaved format: (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)
Definition at line 644 of file ck_parity_api.c.
| void ck_test_softmax | ( | const float * | input, |
| float * | output, | ||
| int | n | ||
| ) |
Softmax (simple, non-causal)
Computes: output[i] = exp(input[i]) / sum(exp(input))
| input | Input tensor [n] |
| output | Output tensor [n] |
| n | Number of elements |
Definition at line 710 of file ck_parity_api.c.
| void ck_test_swiglu | ( | const float * | gate_up, |
| float * | output, | ||
| int | n_tokens, | ||
| int | intermediate_dim | ||
| ) |
SwiGLU activation.
Computes: output = SiLU(gate) * up where SiLU(x) = x * sigmoid(x)
| gate_up | Input tensor [n_tokens, 2 * intermediate_dim] Layout: [gate_0..gate_D-1, up_0..up_D-1] per token |
| output | Output tensor [n_tokens, intermediate_dim] |
| n_tokens | Number of tokens |
| intermediate_dim | Intermediate dimension |
Definition at line 703 of file ck_parity_api.c.
References swiglu_forward().
| void ck_test_vec_dot_q5_0_q8_0 | ( | const void * | weight_q5_0, |
| const void * | input_q8_0, | ||
| float * | output, | ||
| int | cols | ||
| ) |
Direct Q5_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)
Direct Q5_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)
This is a "direct" test that bypasses FP32-to-Q8_0 conversion. Useful for isolating kernel bugs from quantization bugs.
| weight_q5_0 | Q5_0 quantized weights [cols] |
| input_q8_0 | Q8_0 quantized input [cols] (pre-quantized!) |
| output | Output scalar [1] |
| cols | Number of elements (must be multiple of 32) |
Definition at line 364 of file ck_parity_api.c.
References vec_dot_q5_0_q8_0().
| void ck_test_vec_dot_q8_0_q8_0 | ( | const void * | weight_q8_0, |
| const void * | input_q8_0, | ||
| float * | output, | ||
| int | cols | ||
| ) |
Direct Q8_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)
Direct Q8_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)
| weight_q8_0 | Q8_0 quantized weights [cols] |
| input_q8_0 | Q8_0 quantized input [cols] (pre-quantized!) |
| output | Output scalar [1] |
| cols | Number of elements (must be multiple of 32) |
Definition at line 380 of file ck_parity_api.c.
References vec_dot_q8_0_q8_0().
| void dequant_q4_0_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q4_0 row (multiple blocks)
| src | Q4_0 data |
| dst | FP32 output |
| n_elements | Number of elements to dequantize |
Definition at line 61 of file dequant_kernels.c.
Referenced by ck_test_dequant_q4_0(), and dequant_row().
| void dequant_q4_k_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q4_K row (multiple blocks)
Definition at line 370 of file dequant_kernels.c.
Referenced by ck_test_dequant_q4_k(), and dequant_row().
| void dequant_q5_1_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q5_1 row (multiple blocks)
Definition at line 255 of file dequant_kernels.c.
Referenced by ck_test_dequant_q5_1(), and dequant_row().
| void dequant_q6_k_row | ( | const void * | src, |
| float * | dst, | ||
| size_t | n_elements | ||
| ) |
Dequantize Q6_K row (multiple blocks)
Definition at line 420 of file dequant_kernels.c.
Referenced by ck_test_dequant_q6_k(), ck_test_gemv_q6_k(), and dequant_row().
| void geglu_backward_fp32 | ( | const float * | x, |
| const float * | d_out, | ||
| float * | d_x, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
GeGLU backward pass (fp32)
Computes dL/dx given dL/d(out), where out = GELU(a) * b. By the chain rule: dL/da = dL/dout * (dGELU/da) * b, and dL/db = dL/dout * GELU(a).
After changes: make test
Definition at line 843 of file gelu_kernels.c.
Referenced by ck_test_geglu_backward().
| void geglu_forward_fp32 | ( | const float * | x, |
| float * | out, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
GeGLU forward pass (fp32)
Computes out = GELU(a) * b where x = [a, b] along last dimension. Input shape: [tokens, 2 * dim], Output shape: [tokens, dim]
After changes: make test
Definition at line 623 of file gelu_kernels.c.
Referenced by ck_test_geglu(), and geglu_forward_bf16().
| void gemm_nt_q4_k_q8_k | ( | const void * | A_q8, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 295 of file gemm_kernels_q4k_q8k.c.
Referenced by ck_test_gemm_q4_k().
| void gemm_nt_q5_0_q8_0 | ( | const void * | A_q8, |
| const void * | B_q5, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
Computes C = A @ B^T + bias, where: A is [M x K] Q8_0 quantized activations (M tokens, K features); B is [N x K] Q5_0 quantized weights (N outputs, K features); C is the [M x N] FP32 output.
This is the INT8 batch kernel for prefill, using pre-quantized activations to avoid FP32->Q8_0 conversion overhead per operation.
| A_q8 | Input activations in Q8_0 format [M rows of K/32 blocks each] |
| B_q5 | Weights in Q5_0 format [N rows of K/32 blocks each] |
| bias | Optional bias vector [N], NULL if not used |
| C | Output matrix [M x N], row-major FP32 |
| M | Batch size (number of tokens) |
| N | Output dimension (number of output features) |
| K | Input dimension (must be multiple of 32) |
Definition at line 1617 of file gemm_kernels_q5_0.c.
Referenced by ck_test_gemm_q5_0().
| void gemm_nt_q5_1 | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
GEMM with transposed Q5_1 weights: C = A @ B^T.
| A | Input activations [M x K], row-major FP32 |
| B | Weight matrix in Q5_1 format [N x K], row-major quantized |
| bias | Optional bias [N], NULL if not used |
| C | Output [M x N], row-major FP32 |
| M | Batch size (number of tokens) |
| N | Output dimension |
| K | Input dimension |
Definition at line 309 of file gemm_kernels_q5_1.c.
Referenced by ck_test_gemm_q5_1().
| void gemm_nt_q5_k | ( | const float * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 218 of file gemm_kernels_q5_k.c.
Referenced by ck_test_gemm_q5_k().
| void gemm_nt_q6_k_q8_k | ( | const void * | A_q8, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
This is the typical inference pattern:
| A_q8 | Input activations in Q8_K format |
| B | Weight matrix in Q6_K format |
| bias | Optional bias vector [N] |
| C | Output matrix |
| M | Batch size (number of tokens) |
| N | Output dimension |
| K | Input dimension |
Definition at line 1144 of file gemm_kernels_q6k_q8k.c.
Referenced by ck_test_gemm_q6_k().
| void gemm_nt_q8_0_q8_0 | ( | const void * | A, |
| const void * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
C[m,n] = A[m,K] @ B[n,K]^T + bias[n]
Definition at line 582 of file gemm_batch_int8.c.
Referenced by ck_test_gemm_q8_0().
| void gemv_q4_k_q8_k | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K | ||
| ) |
Definition at line 239 of file gemm_kernels_q4k_q8k.c.
Referenced by ck_test_gemv_q4_k().
| void gemv_q5_0 | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
Dispatch priority (best available):
Uses ck_features.h for standardized feature detection.
| y | Output vector [M] |
| W | Weight matrix in Q5_0 format [M x K] |
| x | Input vector [K] |
| M | Number of output rows |
| K | Number of input columns (hidden dimension) |
Definition at line 547 of file gemm_kernels_q5_0.c.
Referenced by dot_q5_0(), gemm_nt_q5_0(), and gemm_q5_0().
| void gemv_q5_0_q8_0 | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K | ||
| ) |
Matrix-vector multiply with Q5_0 weights and Q8_0 input.
| y | Output vector [M] |
| W | Weight matrix in Q5_0 format [M x K] |
| x_q8 | Input vector in Q8_0 format [K] |
| M | Number of output rows |
| K | Number of columns (must be multiple of 32) |
Definition at line 1529 of file gemm_kernels_q5_0.c.
Referenced by ck_test_gemv_q5_0(), and ck_test_gemv_q5_0_q8_0().
| void gemv_q5_1 | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Auto-dispatch GEMV.
Definition at line 184 of file gemm_kernels_q5_1.c.
Referenced by ck_test_gemv_q5_1(), dot_q5_1(), and gemm_q5_1().
| void gemv_q5_k | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Definition at line 199 of file gemm_kernels_q5_k.c.
Referenced by ck_test_gemv_q5_k().
| void gemv_q6_k_q8_k | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K | ||
| ) |
GEMV: y = W @ x where W is Q6_K and x is Q8_K.
Definition at line 980 of file gemm_kernels_q6k_q8k.c.
Referenced by gemm_q6_k_q8_k().
| void gemv_q8_0 | ( | float * | y, |
| const void * | W, | ||
| const float * | x, | ||
| int | M, | ||
| int | K | ||
| ) |
Auto-dispatch GEMV for Q8_0 weights based on CPU features.
Dispatch priority (best available):
Uses ck_features.h for standardized feature detection.
| y | Output vector [M] |
| W | Weight matrix in Q8_0 format [M x K] |
| x | Input vector [K] |
| M | Number of output rows |
| K | Number of input columns (hidden dimension) |
Definition at line 630 of file gemm_kernels_q8_0.c.
Referenced by dot_q8_0(), gemm_nt_q8_0(), and gemm_q8_0().
| void gemv_q8_0_q8_0 | ( | float * | y, |
| const void * | W, | ||
| const void * | x_q8, | ||
| int | M, | ||
| int | K | ||
| ) |
Matrix-vector multiply with Q8_0 weights and Q8_0 input.
| y | Output vector [M] |
| W | Weight matrix in Q8_0 format [M x K] |
| x_q8 | Input vector in Q8_0 format [K] |
| M | Number of output rows |
| K | Number of columns (must be multiple of 32) |
Definition at line 1042 of file gemm_kernels_q8_0.c.
Referenced by ck_test_gemv_q8_0(), and ck_test_gemv_q8_0_q8_0().
| void mega_fused_outproj_mlp_prefill | ( | float * | output, |
| const float * | attn_out, | ||
| const float * | residual, | ||
| const float * | ln2_gamma, | ||
| const void * | wo, | ||
| const float * | bo, | ||
| int | wo_dt, | ||
| const void * | w1, | ||
| const float * | b1, | ||
| int | w1_dt, | ||
| const void * | w2, | ||
| const float * | b2, | ||
| int | w2_dt, | ||
| int | tokens, | ||
| int | embed_dim, | ||
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim, | ||
| int | intermediate_dim, | ||
| int | aligned_intermediate_dim, | ||
| float | eps, | ||
| void * | scratch | ||
| ) |
Referenced by ck_test_outproj_mlp_fused_q5_0().
| size_t mega_fused_outproj_mlp_prefill_scratch_size | ( | int | tokens, |
| int | aligned_embed_dim, | ||
| int | num_heads, | ||
| int | aligned_head_dim, | ||
| int | aligned_intermediate_dim | ||
| ) |
Get scratch buffer size for mega_fused_outproj_mlp_prefill.
Definition at line 159 of file mega_fused_outproj_mlp_prefill.c.
Referenced by ck_test_outproj_mlp_fused_q5_0().
| void quantize_row_q8_0 | ( | const float * | x, |
| void * | vy, | ||
| int | k | ||
| ) |
Quantize FP32 to Q8_0 format (scalar reference)
| x | Input FP32 values |
| vy | Output Q8_0 blocks |
| k | Number of elements (must be multiple of 32) |
Definition at line 59 of file gemm_kernels_q8_0.c.
Referenced by ck_test_gemm_q5_0(), ck_test_gemm_q8_0(), ck_test_gemv_q5_0(), ck_test_gemv_q5_0_q8_0(), ck_test_gemv_q8_0(), and ck_test_gemv_q8_0_q8_0().
| void quantize_row_q8_k | ( | const float * | x, |
| void * | vy, | ||
| int | k | ||
| ) |
Definition at line 107 of file gemm_kernels_q4k_q8k.c.
Referenced by ck_test_gemm_q4_k(), ck_test_gemm_q6_k(), ck_test_gemv_q4_k(), and ck_test_quantize_q8_k().
| void rmsnorm_forward | ( | const float * | input, |
| const float * | gamma, | ||
| float * | output, | ||
| float * | rstd_cache, | ||
| int | tokens, | ||
| int | d_model, | ||
| int | aligned_embed_dim, | ||
| float | eps | ||
| ) |
RMSNorm forward pass
test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens
test_rmsnorm.py::TestRMSNormForward::test_fp32_single
test_rmsnorm.py::TestRMSNormForward::test_perf_rolled
test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat
test_parity.py::test_rmsnorm_parity
RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)
After changes: make test && make llamacpp-parity-full
Definition at line 50 of file rmsnorm_kernels.c.
Referenced by ck_test_rmsnorm().
| void rope_forward_qk | ( | float * | q, |
| float * | k, | ||
| const float * | cos_cache, | ||
| const float * | sin_cache, | ||
| int | num_heads, | ||
| int | num_kv_heads, | ||
| int | num_tokens, | ||
| int | head_dim, | ||
| int | aligned_head_dim, | ||
| int | pos_offset | ||
| ) |
RoPE forward for both Q and K (common inference pattern)
test_rope.py::TestRoPEForward::test_rope_forward_qk
test_fused_attention_decode.py::TestFusedAttentionDecode::test_qk_rope
test_parity.py::test_rope_qk_parity
Combined RoPE forward for both Q and K in one call. Shapes: q is [num_heads, num_tokens, head_dim]; k is [num_kv_heads, num_tokens, head_dim].
After changes: make test && make llamacpp-parity-full
Definition at line 448 of file rope_kernels.c.
Referenced by ck_test_rope().
| void rope_precompute_cache | ( | float * | cos_cache, |
| float * | sin_cache, | ||
| int | max_seq_len, | ||
| int | head_dim, | ||
| float | base | ||
| ) |
Precompute RoPE cos/sin cache
test_rope.py::TestRoPECache::test_cache_computation
test_rope.py::TestRoPECache::test_cache_values
Precomputes cos(m * theta_i) and sin(m * theta_i) for positions m = 0..max_seq_len-1. Both cos_cache and sin_cache have shape [max_seq_len, head_dim/2].
After changes: make test
Definition at line 52 of file rope_kernels.c.
Referenced by ck_test_rope().
| void swiglu_forward | ( | const float * | input, |
| float * | output, | ||
| int | tokens, | ||
| int | dim | ||
| ) |
SwiGLU forward pass
test_swiglu.py::TestSwiGLUForward::test_forward_tokens
test_swiglu.py::TestSwiGLUForward::test_forward_single
test_mlp.py::TestMLPForward::test_swiglu_mlp
test_fused_swiglu_decode.py::TestFusedSwiGLUDecode::test_fused_swiglu_decode
test_parity.py::test_swiglu_parity
SwiGLU: y = silu(gate) * up where silu(x) = x * sigmoid(x)
After changes: make test && make llamacpp-parity-full
Definition at line 131 of file swiglu_kernels.c.
Referenced by ck_test_swiglu().
| void vec_dot_q5_0_q8_0 | ( | int | n, |
| float * | s, | ||
| const void * | vx, | ||
| const void * | vy | ||
| ) |
Auto-dispatch quantized dot product Q5_0 x Q8_0.
Dispatch priority:
Definition at line 1498 of file gemm_kernels_q5_0.c.
Referenced by ck_test_vec_dot_q5_0_q8_0().
| void vec_dot_q8_0_q8_0 | ( | int | n, |
| float * | s, | ||
| const void * | vx, | ||
| const void * | vy | ||
| ) |
Auto-dispatch quantized dot product Q8_0 x Q8_0.
Definition at line 1013 of file gemm_kernels_q8_0.c.
Referenced by ck_test_vec_dot_q8_0_q8_0().