← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ck_parity_api.c File Reference

C-Kernel-Engine Parity Testing API Implementation. More...

#include "ck_parity_api.h"
#include "ckernel_quant.h"
#include <math.h>
#include <stdlib.h>
#include <string.h>

Go to the source code of this file.

Functions

void attention_forward_causal_head_major_gqa_flash_strided (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
 
void attention_forward_causal_head_major_gqa_flash_strided_sliding (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
 
void attention_forward_decode_head_major_gqa_flash_sliding (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
 
int ck_get_block_q4_k_size (void)
 Get Q4_K block size in bytes. More...
 
int ck_get_block_q5_1_size (void)
 Get Q5_1 block size in bytes (24 bytes per 32 weights) More...
 
int ck_get_block_q5_k_size (void)
 Get Q5_K block size in bytes (176 bytes per 256 weights) More...
 
int ck_get_block_q6_k_size (void)
 Get Q6_K block size in bytes. More...
 
int ck_get_block_q8_k_size (void)
 Get Q8_K block size in bytes. More...
 
int ck_get_qk5_1 (void)
 Get QK5_1 (elements per Q5_1 block) More...
 
int ck_get_qk_k (void)
 Get QK_K (elements per super-block) More...
 
void ck_test_attention_causal (const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim)
 Multi-head causal attention for prefill (head-major layout) More...
 
void ck_test_attention_decode_sliding (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int sliding_window)
 Test sliding-window attention (decode mode) More...
 
void ck_test_attention_sliding_window (const float *q, const float *k, const float *v, float *out, int num_heads, int num_kv_heads, int tokens, int seq_len, int head_dim, int sliding_window)
 Test sliding-window attention (prefill) More...
 
void ck_test_dequant_q4_0 (const void *src, float *dst, int n)
 Dequantize Q4_0 data to FP32. More...
 
void ck_test_dequant_q4_k (const void *src, float *dst, int n)
 Dequantize Q4_K data to FP32. More...
 
void ck_test_dequant_q5_1 (const void *src, float *dst, int n)
 Dequantize Q5_1 data to FP32. More...
 
void ck_test_dequant_q6_k (const void *src, float *dst, int n)
 Dequantize Q6_K data to FP32. More...
 
void ck_test_geglu (const float *x, float *out, int n_tokens, int dim)
 Test GeGLU activation. More...
 
void ck_test_geglu_backward (const float *x, const float *d_out, float *d_x, int n_tokens, int dim)
 Test GeGLU backward. More...
 
void ck_test_gemm_q4_k (const void *weight_q4k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
 Q4_K GEMM - batched matrix multiply with quantized weights. More...
 
void ck_test_gemm_q5_0 (const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
 Test Q5_0 x Q8_0 GEMM (batch matrix multiply) More...
 
void ck_test_gemm_q5_1 (const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols, int n_tokens)
 Test Q5_1 x Q8_0 GEMM (batch matrix multiply) More...
 
void ck_test_gemm_q5_k (const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
 Test Q5_K x Q8_K GEMM (batch matrix multiply) More...
 
void ck_test_gemm_q6_k (const void *weight_q6k, const float *input_f32, float *output, int rows, int cols, int n_tokens)
 Test Q6_K x Q8_K GEMM (batch matrix multiply) More...
 
void ck_test_gemm_q8_0 (const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols, int n_tokens)
 Test Q8_0 x Q8_0 GEMM (batch matrix multiply) More...
 
void ck_test_gemv_q4_k (const void *weight_q4k, const float *input_f32, float *output, int cols)
 Q4_K GEMV - dot product of quantized weights and FP32 input. More...
 
void ck_test_gemv_q5_0 (const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
 Q5_0 GEMV - matrix-vector multiply with Q5_0 weights. More...
 
void ck_test_gemv_q5_0_q8_0 (const void *weight_q5_0, const float *input_f32, float *output, int rows, int cols)
 Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach. More...
 
void ck_test_gemv_q5_1 (const void *weight_q5_1, const float *input_f32, float *output, int rows, int cols)
 Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks) More...
 
void ck_test_gemv_q5_k (const void *weight_q5_k, const float *input_f32, float *output, int rows, int cols)
 Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks) More...
 
void ck_test_gemv_q6_k (const void *weight_q6k, const float *input_f32, float *output, int cols)
 Q6_K GEMV. More...
 
void ck_test_gemv_q8_0 (const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
 Q8_0 GEMV - matrix-vector multiply with Q8_0 weights. More...
 
void ck_test_gemv_q8_0_q8_0 (const void *weight_q8_0, const float *input_f32, float *output, int rows, int cols)
 Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach. More...
 
void ck_test_outproj_mlp_fused_q5_0 (const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const void *w1, const void *w2, float *output, int tokens, int num_heads, int head_dim, int embed_dim, int intermediate, float eps, int w2_is_q6k)
 Test mega-fused OutProj + MLP kernel (Q5_0 weights) More...
 
void ck_test_quantize_q8_k (const float *src, void *dst, int n)
 Quantize FP32 to Q8_K (for activations) More...
 
void ck_test_rmsnorm (const float *input, const float *weight, float *output, int n_tokens, int dim, float eps)
 RMSNorm. More...
 
void ck_test_rope (float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
 RoPE (Rotary Position Embedding) More...
 
void ck_test_rope_interleaved (float *q, float *k, int n_tokens, int n_heads, int n_heads_kv, int head_dim, int pos_offset, float theta)
 RoPE with interleaved format (for llama.cpp compatibility) More...
 
void ck_test_softmax (const float *input, float *output, int n)
 Softmax (simple, non-causal) More...
 
void ck_test_swiglu (const float *gate_up, float *output, int n_tokens, int intermediate_dim)
 SwiGLU activation. More...
 
void ck_test_vec_dot_q5_0_q8_0 (const void *weight_q5_0, const void *input_q8_0, float *output, int cols)
 Direct Q5_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input) More...
 
void ck_test_vec_dot_q8_0_q8_0 (const void *weight_q8_0, const void *input_q8_0, float *output, int cols)
 Direct Q8_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input) More...
 
void dequant_q4_0_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q4_0 row (multiple blocks) More...
 
void dequant_q4_k_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q4_K row (multiple blocks) More...
 
void dequant_q5_1_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q5_1 row (multiple blocks) More...
 
void dequant_q6_k_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q6_K row (multiple blocks) More...
 
void geglu_backward_fp32 (const float *x, const float *d_out, float *d_x, int n_tokens, int dim)
 
void geglu_forward_fp32 (const float *x, float *out, int tokens, int dim)
 
void gemm_nt_q4_k_q8_k (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nt_q5_0_q8_0 (const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
 Batch GEMM with Q5_0 weights and Q8_0 activations for prefill. More...
 
void gemm_nt_q5_1 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 GEMM with transposed Q5_1 weights: C = A @ B^T. More...
 
void gemm_nt_q5_k (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nt_q6_k_q8_k (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
 NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K. More...
 
void gemm_nt_q8_0_q8_0 (const void *A_q8, const void *B_q8, const float *bias, float *C, int M, int N, int K)
 gemm_nt_q8_0_q8_0 with optional bias (matches header signature) More...
 
void gemv_q4_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K)
 
void gemv_q5_0 (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV for Q5_0 weights based on CPU features. More...
 
void gemv_q5_0_q8_0 (float *y, const void *W, const void *x_q8, int M, int K)
 Matrix-vector multiply with Q5_0 weights and Q8_0 input. More...
 
void gemv_q5_1 (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV. More...
 
void gemv_q5_k (float *y, const void *W, const float *x, int M, int K)
 
void gemv_q6_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K)
 GEMV: y = W @ x where W is Q6_K and x is Q8_K. More...
 
void gemv_q8_0 (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV for Q8_0 weights based on CPU features. More...
 
void gemv_q8_0_q8_0 (float *y, const void *W, const void *x_q8, int M, int K)
 Matrix-vector multiply with Q8_0 weights and Q8_0 input. More...
 
void mega_fused_outproj_mlp_prefill (float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, int wo_dt, const void *w1, const float *b1, int w1_dt, const void *w2, const float *b2, int w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
 
size_t mega_fused_outproj_mlp_prefill_scratch_size (int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
 Get scratch buffer size for mega_fused_outproj_mlp_prefill. More...
 
void quantize_row_q8_0 (const float *x, void *vy, int k)
 Quantize FP32 to Q8_0 format (scalar reference) More...
 
void quantize_row_q8_k (const float *x, void *vy, int k)
 
void rmsnorm_forward (const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
 
void rope_forward_qk (float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_precompute_cache (float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
 
void swiglu_forward (const float *input, float *output, int tokens, int dim)
 
void vec_dot_q5_0_q8_0 (int n, float *s, const void *vx, const void *vy)
 Auto-dispatch quantized dot product Q5_0 x Q8_0. More...
 
void vec_dot_q8_0_q8_0 (int n, float *s, const void *vx, const void *vy)
 Auto-dispatch quantized dot product Q8_0 x Q8_0. More...
 

Detailed Description

C-Kernel-Engine Parity Testing API Implementation.

Wraps CK kernels for parity testing against llama.cpp/ggml.

Definition in file ck_parity_api.c.

Function Documentation

◆ attention_forward_causal_head_major_gqa_flash_strided()

void attention_forward_causal_head_major_gqa_flash_strided ( const float *  q,
const float *  k,
const float *  v,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  kv_stride_tokens 
)

Flash attention forward with custom KV stride (for KV cache)

Test:

test_flash_attention.py::TestFlashAttention::test_flash_strided

test_kv_cache_attention.py::TestKVCacheAttention::test_flash_attention

Variant with configurable kv_stride_tokens for KV cache layouts where K/V may not be contiguous in memory.

After changes: make test

Definition at line 859 of file attention_kernels.c.

869 {
870  if (!q || !k || !v || !output) {
871  return;
872  }
873  if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
874  return;
875  }
876  if (kv_stride_tokens < num_tokens) {
877  return;
878  }
879 
880  const float scale = 1.0f / sqrtf((float)head_dim);
881  const int T = num_tokens;
882  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
883 
884  // Select SIMD implementation based on compile-time CPU features
885 #if defined(__AVX512F__)
886  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx512
887 #elif defined(__AVX2__)
888  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx2
889 #elif defined(__AVX__)
890  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx
891 #else
892  #define FLASH_QUERY_IMPL attention_flash_query_causal
893 #endif
894 
895  for (int h = 0; h < num_heads; ++h) {
896  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
897  const float *k_head = k + (size_t)kv_head * kv_head_stride;
898  const float *v_head = v + (size_t)kv_head * kv_head_stride;
899 
900  for (int i = 0; i < T; ++i) {
901  const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
902  float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
903  FLASH_QUERY_IMPL(q_vec, k_head, v_head,
904  /*kv_tokens=*/i + 1,
905  head_dim, aligned_head_dim,
906  scale, out_vec);
907  }
908  }
909 
910 #undef FLASH_QUERY_IMPL
911 }
static size_t qkv_index(int h, int t, int d, int num_tokens, int aligned_head_dim)
#define FLASH_QUERY_IMPL

Referenced by ck_test_attention_causal().

◆ attention_forward_causal_head_major_gqa_flash_strided_sliding()

void attention_forward_causal_head_major_gqa_flash_strided_sliding ( const float *  q,
const float *  k,
const float *  v,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  kv_stride_tokens,
int  sliding_window 
)

Flash attention forward with sliding window (prefill)

Test:
test_attention.py::TestAttentionForward::test_sliding_window_prefill

Sliding-window attention for prefill: each token attends to the last W tokens. When sliding_window <= 0, behaves like regular causal attention.

After changes: make test

Definition at line 1316 of file attention_kernels.c.

1328 {
1329  if (!q || !k || !v || !output) {
1330  return;
1331  }
1332  if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
1333  return;
1334  }
1335  if (kv_stride_tokens < num_tokens) {
1336  return;
1337  }
1338 
1339  const float scale = 1.0f / sqrtf((float)head_dim);
1340  const int T = num_tokens;
1341  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
1342 
1343 #if defined(__AVX512F__)
1344  #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx512
1345 #elif defined(__AVX2__)
1346  #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx2
1347 #elif defined(__AVX__)
1348  #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx
1349 #else
1350  #define SLIDING_FLASH_IMPL attention_flash_query_sliding
1351 #endif
1352 
1353  for (int h = 0; h < num_heads; ++h) {
1354  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1355  const float *k_head = k + (size_t)kv_head * kv_head_stride;
1356  const float *v_head = v + (size_t)kv_head * kv_head_stride;
1357 
1358  for (int i = 0; i < T; ++i) {
1359  const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
1360  float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
1361  SLIDING_FLASH_IMPL(q_vec, k_head, v_head,
1362  /*query_pos=*/i,
1363  /*kv_tokens=*/T,
1364  head_dim, aligned_head_dim,
1365  scale, out_vec,
1366  sliding_window);
1367  }
1368  }
1369 
1370 #undef SLIDING_FLASH_IMPL
1371 }
#define SLIDING_FLASH_IMPL

Referenced by ck_test_attention_sliding_window().

◆ attention_forward_decode_head_major_gqa_flash_sliding()

void attention_forward_decode_head_major_gqa_flash_sliding ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim,
int  sliding_window 
)

Flash attention decode with sliding window

Test:
test_attention.py::TestAttentionForward::test_sliding_window_decode

Single query token attends to the last W tokens in the KV cache. For decode: effective_kv_tokens = min(kv_tokens, sliding_window)

After changes: make test

Definition at line 1382 of file attention_kernels.c.

1394 {
1395  if (!q_token || !k_cache || !v_cache || !out_token) {
1396  return;
1397  }
1398  if (num_heads <= 0 || num_kv_heads <= 0 || cache_capacity <= 0) {
1399  return;
1400  }
1401  if (kv_tokens <= 0 || kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
1402  return;
1403  }
1404 
1405  const float scale = 1.0f / sqrtf((float)head_dim);
1406  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
1407 
1408  // Compute effective KV tokens based on sliding window
1409  int effective_kv_tokens = kv_tokens;
1410  if (sliding_window > 0 && sliding_window < kv_tokens) {
1411  effective_kv_tokens = sliding_window;
1412  }
1413 
1414  // Guard against empty window (shouldn't happen with kv_tokens >= 1)
1415  if (effective_kv_tokens <= 0) {
1416  return;
1417  }
1418 
1419  // Offset to start reading from the last effective_kv_tokens entries
1420  int kv_start_offset = kv_tokens - effective_kv_tokens;
1421 
1422 #if defined(__AVX512F__)
1423  #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx512
1424 #elif defined(__AVX2__)
1425  #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx2
1426 #elif defined(__AVX__)
1427  #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx
1428 #else
1429  #define SLIDING_DECODE_IMPL attention_flash_query_sliding
1430 #endif
1431 
1432  for (int h = 0; h < num_heads; ++h) {
1433  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1434  const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
1435  // Offset K/V pointer to start from the first token in the sliding window
1436  const float *k_head = k_cache + (size_t)kv_head * head_stride
1437  + (size_t)kv_start_offset * (size_t)aligned_head_dim;
1438  const float *v_head = v_cache + (size_t)kv_head * head_stride
1439  + (size_t)kv_start_offset * (size_t)aligned_head_dim;
1440  float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;
1441 
1442  // Use query_pos relative to the windowed KV (last token = effective_kv_tokens - 1)
1443  // sliding_window = 0 since we've already windowed via K/V pointer offset
1444  SLIDING_DECODE_IMPL(q_head, k_head, v_head,
1445  /*query_pos=*/effective_kv_tokens - 1,
1446  /*kv_tokens=*/effective_kv_tokens,
1447  head_dim, aligned_head_dim,
1448  scale, out_head,
1449  /*sliding_window=*/0);
1450  }
1451 
1452 #undef SLIDING_DECODE_IMPL
1453 }
#define SLIDING_DECODE_IMPL

Referenced by ck_test_attention_decode_sliding().

◆ ck_get_block_q4_k_size()

int ck_get_block_q4_k_size ( void  )

Get Q4_K block size in bytes.

Definition at line 961 of file ck_parity_api.c.

962 {
963  return sizeof(block_q4_K);
964 }

◆ ck_get_block_q5_1_size()

int ck_get_block_q5_1_size ( void  )

Get Q5_1 block size in bytes (24 bytes per 32 weights)

Definition at line 986 of file ck_parity_api.c.

987 {
988  return sizeof(block_q5_1);
989 }

◆ ck_get_block_q5_k_size()

int ck_get_block_q5_k_size ( void  )

Get Q5_K block size in bytes (176 bytes per 256 weights)

Definition at line 981 of file ck_parity_api.c.

982 {
983  return sizeof(block_q5_K);
984 }

◆ ck_get_block_q6_k_size()

int ck_get_block_q6_k_size ( void  )

Get Q6_K block size in bytes.

Definition at line 966 of file ck_parity_api.c.

967 {
968  return sizeof(block_q6_K);
969 }

◆ ck_get_block_q8_k_size()

int ck_get_block_q8_k_size ( void  )

Get Q8_K block size in bytes.

Definition at line 971 of file ck_parity_api.c.

972 {
973  return sizeof(block_q8_K);
974 }

◆ ck_get_qk5_1()

int ck_get_qk5_1 ( void  )

Get QK5_1 (elements per Q5_1 block)

Definition at line 991 of file ck_parity_api.c.

992 {
993  return QK5_1;
994 }
#define QK5_1
Definition: ckernel_quant.h:84

References QK5_1.

◆ ck_get_qk_k()

int ck_get_qk_k ( void  )

Get QK_K (elements per super-block)

Definition at line 976 of file ck_parity_api.c.

977 {
978  return QK_K;
979 }
#define QK_K

References QK_K.

◆ ck_test_attention_causal()

void ck_test_attention_causal ( const float *  q,
const float *  k,
const float *  v,
float *  out,
int  num_heads,
int  num_kv_heads,
int  tokens,
int  seq_len,
int  head_dim 
)

Multi-head causal attention for prefill (head-major layout)

Layout (head-major, matches llama.cpp test): Q: [num_heads, tokens, head_dim] K: [num_kv_heads, seq_len, head_dim] V: [num_kv_heads, seq_len, head_dim] out: [num_heads, tokens, head_dim]

Supports GQA (grouped-query attention) where num_heads > num_kv_heads. Causal masking: token t can only attend to positions 0..t (inclusive).

Parameters
q — Query [num_heads, tokens, head_dim]
k — Key [num_kv_heads, seq_len, head_dim]
v — Value [num_kv_heads, seq_len, head_dim]
out — Output [num_heads, tokens, head_dim]
num_heads — Number of query heads
num_kv_heads — Number of key/value heads (for GQA)
tokens — Number of query tokens
seq_len — Key/value sequence length (for prefill: seq_len == tokens)
head_dim — Dimension per head

Definition at line 736 of file ck_parity_api.c.

745 {
746  /* For prefill, seq_len == tokens, and kv_stride == tokens.
747  * The CK kernel expects strided KV layout with kv_stride_tokens parameter.
748  * For parity testing with contiguous tensors, kv_stride = seq_len.
749  */
751  q, k, v, out,
752  num_heads, num_kv_heads, tokens,
753  head_dim, head_dim, /* aligned_head_dim = head_dim for testing */
754  seq_len /* kv_stride_tokens = seq_len for contiguous KV */
755  );
756 }
void attention_forward_causal_head_major_gqa_flash_strided(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)

References attention_forward_causal_head_major_gqa_flash_strided().

◆ ck_test_attention_decode_sliding()

void ck_test_attention_decode_sliding ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  sliding_window 
)

Test sliding-window attention (decode mode)

Single query token attending to KV cache with sliding window.

Definition at line 794 of file ck_parity_api.c.

804 {
806  q_token, k_cache, v_cache, out_token,
807  num_heads, num_kv_heads,
808  kv_tokens, cache_capacity, head_dim, head_dim,
809  sliding_window
810  );
811 }
void attention_forward_decode_head_major_gqa_flash_sliding(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)

References attention_forward_decode_head_major_gqa_flash_sliding().

◆ ck_test_attention_sliding_window()

void ck_test_attention_sliding_window ( const float *  q,
const float *  k,
const float *  v,
float *  out,
int  num_heads,
int  num_kv_heads,
int  tokens,
int  seq_len,
int  head_dim,
int  sliding_window 
)

Test sliding-window attention (prefill)

Layout (head-major, matching CK-Engine): Q: [num_heads, tokens, head_dim] K: [num_kv_heads, seq_len, head_dim] V: [num_kv_heads, seq_len, head_dim] out: [num_heads, tokens, head_dim]

Each token attends only to the last sliding_window tokens.

Definition at line 769 of file ck_parity_api.c.

779 {
781  q, k, v, out,
782  num_heads, num_kv_heads, tokens,
783  head_dim, head_dim, /* aligned_head_dim = head_dim for testing */
784  seq_len, /* kv_stride_tokens = seq_len for contiguous KV */
785  sliding_window
786  );
787 }
void attention_forward_causal_head_major_gqa_flash_strided_sliding(const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)

References attention_forward_causal_head_major_gqa_flash_strided_sliding().

◆ ck_test_dequant_q4_0()

void ck_test_dequant_q4_0 ( const void *  src,
float *  dst,
int  n 
)

Dequantize Q4_0 data to FP32.

Definition at line 122 of file ck_parity_api.c.

123 {
124  dequant_q4_0_row(src, dst, (size_t)n);
125 }
void dequant_q4_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_0 row (multiple blocks)

References dequant_q4_0_row().

◆ ck_test_dequant_q4_k()

void ck_test_dequant_q4_k ( const void *  src,
float *  dst,
int  n 
)

Dequantize Q4_K data to FP32.

Parameters
src — Input Q4_K blocks
dst — Output FP32 values
n — Number of elements (must be multiple of 256)

Definition at line 112 of file ck_parity_api.c.

113 {
114  dequant_q4_k_row(src, dst, (size_t)n);
115 }
void dequant_q4_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_K row (multiple blocks)

References dequant_q4_k_row().

◆ ck_test_dequant_q5_1()

void ck_test_dequant_q5_1 ( const void *  src,
float *  dst,
int  n 
)

Dequantize Q5_1 data to FP32.

Definition at line 127 of file ck_parity_api.c.

128 {
129  dequant_q5_1_row(src, dst, (size_t)n);
130 }
void dequant_q5_1_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_1 row (multiple blocks)

References dequant_q5_1_row().

◆ ck_test_dequant_q6_k()

void ck_test_dequant_q6_k ( const void *  src,
float *  dst,
int  n 
)

Dequantize Q6_K data to FP32.

Definition at line 117 of file ck_parity_api.c.

118 {
119  dequant_q6_k_row(src, dst, (size_t)n);
120 }
void dequant_q6_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q6_K row (multiple blocks)

References dequant_q6_k_row().

◆ ck_test_geglu()

void ck_test_geglu ( const float *  x,
float *  out,
int  n_tokens,
int  dim 
)

Test GeGLU activation.

Computes: output = GELU(a) * b where input contains [a, b] concatenated along the last dimension.

Definition at line 819 of file ck_parity_api.c.

823 {
824  geglu_forward_fp32(x, out, n_tokens, dim);
825 }
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim)
Definition: gelu_kernels.c:623

References geglu_forward_fp32().

◆ ck_test_geglu_backward()

void ck_test_geglu_backward ( const float *  x,
const float *  d_out,
float *  d_x,
int  n_tokens,
int  dim 
)

Test GeGLU backward.

Computes gradients dL/dx given dL/d(out) where out = GELU(a) * b

Definition at line 832 of file ck_parity_api.c.

837 {
838  geglu_backward_fp32(x, d_out, d_x, n_tokens, dim);
839 }
void geglu_backward_fp32(const float *x, const float *d_out, float *d_x, int n_tokens, int dim)
Definition: gelu_kernels.c:843

References geglu_backward_fp32().

◆ ck_test_gemm_q4_k()

void ck_test_gemm_q4_k ( const void *  weight_q4k,
const float *  input_f32,
float *  output,
int  rows,
int  cols,
int  n_tokens 
)

Q4_K GEMM - batched matrix multiply with quantized weights.

Computes: output[t,r] = sum_k(weight[r,k] * input[t,k])

Parameters
weight_q4k — Q4_K quantized weights [rows, cols]
input_f32 — FP32 input [n_tokens, cols]
output — FP32 output [n_tokens, rows]
rows — Number of output rows
cols — Number of columns (must be multiple of 256)
n_tokens — Batch size

Definition at line 392 of file ck_parity_api.c.

396 {
397  /* Allocate Q8_K buffer for quantized activations */
398  int n_blocks_per_row = cols / CK_QK_K;
399  block_q8_K *q8_data = (block_q8_K *)malloc(n_tokens * n_blocks_per_row * sizeof(block_q8_K));
400  if (!q8_data) {
401  memset(output, 0, n_tokens * rows * sizeof(float));
402  return;
403  }
404 
405  /* Quantize all input tokens */
406  for (int t = 0; t < n_tokens; t++) {
407  quantize_row_q8_k(input_f32 + t * cols,
408  q8_data + t * n_blocks_per_row, cols);
409  }
410 
411  /* Use gemm_nt_q4_k_q8_k: C[M,N] = A[M,K] * B[N,K]^T
412  * Our layout: output[n_tokens, rows] = input[n_tokens, cols] * weight[rows, cols]^T
413  * So: M = n_tokens, N = rows, K = cols
414  */
415  gemm_nt_q4_k_q8_k(q8_data, weight_q4k, NULL, output, n_tokens, rows, cols);
416 
417  free(q8_data);
418 }
void quantize_row_q8_k(const float *x, void *vy, int k)
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
#define CK_QK_K
Definition: ck_parity_api.h:28

References CK_QK_K, gemm_nt_q4_k_q8_k(), and quantize_row_q8_k().

◆ ck_test_gemm_q5_0()

void ck_test_gemm_q5_0 ( const void *  weight_q5_0,
const float *  input_f32,
float *  output,
int  rows,
int  cols,
int  n_tokens 
)

Test Q5_0 x Q8_0 GEMM (batch matrix multiply)

Q5_0 GEMM - batched matrix multiply with Q5_0 weights (32-element blocks)

Used for MLP W1 (gate/up projection) and attention Q/K with Q5_0 weights.

Definition at line 491 of file ck_parity_api.c.

495 {
496  /* Allocate Q8_0 buffer for quantized activations */
497  int n_blocks_per_row = cols / CK_QK8_0;
498  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_tokens * n_blocks_per_row * sizeof(block_q8_0));
499  if (!q8_data) {
500  memset(output, 0, n_tokens * rows * sizeof(float));
501  return;
502  }
503 
504  /* Quantize all input tokens */
505  for (int t = 0; t < n_tokens; t++) {
506  quantize_row_q8_0(input_f32 + t * cols,
507  q8_data + t * n_blocks_per_row, cols);
508  }
509 
510  /* Use gemm_nt_q5_0_q8_0: C[M,N] = A[M,K] * B[N,K]^T
511  * Our layout: output[n_tokens, rows] = input[n_tokens, cols] * weight[rows, cols]^T
512  * So: M = n_tokens, N = rows, K = cols
513  */
514  gemm_nt_q5_0_q8_0(q8_data, weight_q5_0, NULL, output, n_tokens, rows, cols);
515 
516  free(q8_data);
517 }
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
void quantize_row_q8_0(const float *x, void *vy, int k)
Quantize FP32 to Q8_0 format (scalar reference)
#define CK_QK8_0
Definition: ck_parity_api.h:30

References CK_QK8_0, gemm_nt_q5_0_q8_0(), and quantize_row_q8_0().

◆ ck_test_gemm_q5_1()

void ck_test_gemm_q5_1 ( const void *  weight_q5_1,
const float *  input_f32,
float *  output,
int  rows,
int  cols,
int  n_tokens 
)

Test Q5_1 x Q8_0 GEMM (batch matrix multiply)

Q5_1 GEMM - batched matrix multiply with Q5_1 weights (32-element blocks)

Used for MLP W1 (gate/up projection) and attention Q/K with Q5_1 weights. gemm_nt_q5_1 expects FP32 activations (not quantized).

Definition at line 542 of file ck_parity_api.c.

546 {
547  /* gemm_nt_q5_1 expects FP32 activations, not quantized.
548  * Pass input_f32 directly as-is (already FP32).
549  */
550  gemm_nt_q5_1(input_f32, weight_q5_1, NULL, output, n_tokens, rows, cols);
551 }
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.

References gemm_nt_q5_1().

◆ ck_test_gemm_q5_k()

void ck_test_gemm_q5_k ( const void *  weight_q5_k,
const float *  input_f32,
float *  output,
int  rows,
int  cols,
int  n_tokens 
)

Test Q5_K x Q8_K GEMM (batch matrix multiply)

Q5_K GEMM - batched matrix multiply with Q5_K weights (256-element super-blocks)

Used for MLP W1 (gate/up projection) and attention Q/K with Q5_K weights. gemm_nt_q5_k expects FP32 activations (not quantized).

Definition at line 525 of file ck_parity_api.c.

529 {
530  /* gemm_nt_q5_k expects FP32 activations, not quantized.
531  * Pass input_f32 directly as-is (already FP32).
532  */
533  gemm_nt_q5_k(input_f32, weight_q5_k, NULL, output, n_tokens, rows, cols);
534 }
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)

References gemm_nt_q5_k().

◆ ck_test_gemm_q6_k()

void ck_test_gemm_q6_k ( const void *  weight_q6k,
const float *  input_f32,
float *  output,
int  rows,
int  cols,
int  n_tokens 
)

Test Q6_K x Q8_K GEMM (batch matrix multiply)

Q6_K GEMM - batched matrix multiply with Q6_K weights.

Used for MLP W2 (down projection) with Q6_K weights.

Definition at line 425 of file ck_parity_api.c.

429 {
430  /* Allocate Q8_K buffer for quantized activations */
431  int n_blocks_per_row = cols / CK_QK_K;
432  block_q8_K *q8_data = (block_q8_K *)malloc(n_tokens * n_blocks_per_row * sizeof(block_q8_K));
433  if (!q8_data) {
434  memset(output, 0, n_tokens * rows * sizeof(float));
435  return;
436  }
437 
438  /* Quantize all input tokens */
439  for (int t = 0; t < n_tokens; t++) {
440  quantize_row_q8_k(input_f32 + t * cols,
441  q8_data + t * n_blocks_per_row, cols);
442  }
443 
444  /* Use gemm_nt_q6_k_q8_k: C[M,N] = A[M,K] * B[N,K]^T
445  * Our layout: output[n_tokens, rows] = input[n_tokens, cols] * weight[rows, cols]^T
446  * So: M = n_tokens, N = rows, K = cols
447  */
448  gemm_nt_q6_k_q8_k(q8_data, weight_q6k, NULL, output, n_tokens, rows, cols);
449 
450  free(q8_data);
451 }
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.

References CK_QK_K, gemm_nt_q6_k_q8_k(), and quantize_row_q8_k().

◆ ck_test_gemm_q8_0()

void ck_test_gemm_q8_0 ( const void *  weight_q8_0,
const float *  input_f32,
float *  output,
int  rows,
int  cols,
int  n_tokens 
)

Test Q8_0 x Q8_0 GEMM (batch matrix multiply)

Q8_0 GEMM - batched matrix multiply with Q8_0 weights (32-element blocks)

Used for attention V projection with Q8_0 weights.

Definition at line 458 of file ck_parity_api.c.

462 {
463  /* Allocate Q8_0 buffer for quantized activations */
464  int n_blocks_per_row = cols / CK_QK8_0;
465  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_tokens * n_blocks_per_row * sizeof(block_q8_0));
466  if (!q8_data) {
467  memset(output, 0, n_tokens * rows * sizeof(float));
468  return;
469  }
470 
471  /* Quantize all input tokens */
472  for (int t = 0; t < n_tokens; t++) {
473  quantize_row_q8_0(input_f32 + t * cols,
474  q8_data + t * n_blocks_per_row, cols);
475  }
476 
477  /* Use gemm_nt_q8_0_q8_0: C[M,N] = A[M,K] * B[N,K]^T
478  * Our layout: output[n_tokens, rows] = input[n_tokens, cols] * weight[rows, cols]^T
479  * So: M = n_tokens, N = rows, K = cols
480  */
481  gemm_nt_q8_0_q8_0(q8_data, weight_q8_0, NULL, output, n_tokens, rows, cols);
482 
483  free(q8_data);
484 }
void gemm_nt_q8_0_q8_0(const void *A_q8, const void *B_q8, const float *bias, float *C, int M, int N, int K)
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)

References CK_QK8_0, gemm_nt_q8_0_q8_0(), and quantize_row_q8_0().

◆ ck_test_gemv_q4_k()

void ck_test_gemv_q4_k ( const void *  weight_q4k,
const float *  input_f32,
float *  output,
int  cols 
)

Q4_K GEMV - dot product of quantized weights and FP32 input.

Internally quantizes input to Q8_K, then computes dot product.

Parameters
weight_q4kQ4_K quantized weights [cols]
input_f32FP32 input vector [cols]
outputOutput scalar [1]
colsNumber of columns (must be multiple of 256)

Definition at line 145 of file ck_parity_api.c.

149 {
150  /* Allocate Q8_K buffer for quantized activations */
151  int n_blocks = cols / CK_QK_K;
152  block_q8_K *q8_data = (block_q8_K *)malloc(n_blocks * sizeof(block_q8_K));
153  if (!q8_data) {
154  *output = 0.0f;
155  return;
156  }
157 
158  /* Quantize input to Q8_K */
159  quantize_row_q8_k(input_f32, q8_data, cols);
160 
161  /* Compute dot product using GEMV with M=1 */
162  gemv_q4_k_q8_k(output, weight_q4k, q8_data, 1, cols);
163 
164  free(q8_data);
165 }
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)

References CK_QK_K, gemv_q4_k_q8_k(), and quantize_row_q8_k().

◆ ck_test_gemv_q5_0()

void ck_test_gemv_q5_0 ( const void *  weight_q5_0,
const float *  input_f32,
float *  output,
int  rows,
int  cols 
)

Q5_0 GEMV - matrix-vector multiply with Q5_0 weights.

Parameters
weight_q5_0Q5_0 quantized weights [rows * cols]
input_f32FP32 input vector [cols]
outputFP32 output vector [rows]
rowsNumber of output rows
colsNumber of columns (must be multiple of 32)

Definition at line 192 of file ck_parity_api.c.

196 {
197  /* Match llama.cpp's test_gemv_q5_0:
198  * 1. Quantize input to Q8_0 format
199  * 2. Use quantized dot product (vec_dot_q5_0_q8_0)
200  *
201  * This ensures parity with llama.cpp which always uses the
202  * quantized path, NOT the FP32 dequantization path.
203  */
204  int n_blocks = cols / CK_QK8_0;
205  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_blocks * sizeof(block_q8_0));
206  if (!q8_data) {
207  for (int r = 0; r < rows; r++) output[r] = 0.0f;
208  return;
209  }
210 
211  /* Quantize input to Q8_0 */
212  quantize_row_q8_0(input_f32, q8_data, cols);
213 
214  /* Call the quantized GEMV kernel (same as ck_test_gemv_q5_0_q8_0) */
215  gemv_q5_0_q8_0(output, weight_q5_0, q8_data, rows, cols);
216 
217  free(q8_data);
218 }
void gemv_q5_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q5_0 weights and Q8_0 input.

References CK_QK8_0, gemv_q5_0_q8_0(), and quantize_row_q8_0().

◆ ck_test_gemv_q5_0_q8_0()

void ck_test_gemv_q5_0_q8_0 ( const void *  weight_q5_0,
const float *  input_f32,
float *  output,
int  rows,
int  cols 
)

Q5_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.

This version quantizes the input to Q8_0 first, then uses integer dot products (like llama.cpp does). Use this for parity testing.

Parameters
weight_q5_0Q5_0 quantized weights [rows * cols]
input_f32FP32 input vector [cols] - will be quantized to Q8_0
outputFP32 output vector [rows]
rowsNumber of output rows
colsNumber of columns (must be multiple of 32)

Definition at line 248 of file ck_parity_api.c.

252 {
253  /* This matches llama.cpp's approach:
254  * 1. Quantize input to Q8_0 format
255  * 2. Use quantized dot product (integer math)
256  * 3. Scale at the end
257  */
258  int n_blocks = cols / CK_QK8_0;
259  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_blocks * sizeof(block_q8_0));
260  if (!q8_data) {
261  for (int r = 0; r < rows; r++) output[r] = 0.0f;
262  return;
263  }
264 
265  /* Quantize input to Q8_0 */
266  quantize_row_q8_0(input_f32, q8_data, cols);
267 
268  /* Call the quantized GEMV kernel */
269  gemv_q5_0_q8_0(output, weight_q5_0, q8_data, rows, cols);
270 
271  free(q8_data);
272 }

References CK_QK8_0, gemv_q5_0_q8_0(), and quantize_row_q8_0().

◆ ck_test_gemv_q5_1()

void ck_test_gemv_q5_1 ( const void *  weight_q5_1,
const float *  input_f32,
float *  output,
int  rows,
int  cols 
)

Q5_1 GEMV - matrix-vector multiply with Q5_1 weights (32-element blocks)

Uses raw FP32 activations: unlike ck_test_gemv_q5_0(), the input is NOT pre-quantized to Q8_0, because gemv_q5_1() takes a float* input directly.

Parameters
weight_q5_1Q5_1 quantized weights [rows * cols]
input_f32FP32 input vector [cols]
outputFP32 output vector [rows]
rowsNumber of output rows
colsNumber of columns (must be multiple of 32)

Definition at line 333 of file ck_parity_api.c.

337 {
338  /*
339  * IMPORTANT: gemv_q5_1() expects raw FP32 activations, NOT pre-quantized Q8_0.
340  * See comment in ck_test_gemv_q5_k() above for explanation.
341  */
342  for (int r = 0; r < rows; r++) {
343  gemv_q5_1(&output[r],
344  (const char *)weight_q5_1 + r * (cols / QK5_1) * sizeof(block_q5_1),
345  input_f32, 1, cols);
346  }
347 }
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.

References gemv_q5_1(), and QK5_1.

◆ ck_test_gemv_q5_k()

void ck_test_gemv_q5_k ( const void *  weight_q5_k,
const float *  input_f32,
float *  output,
int  rows,
int  cols 
)

Q5_K GEMV - matrix-vector multiply with Q5_K weights (256-element super-blocks)

Uses raw FP32 activations: unlike ck_test_gemv_q4_k(), the input is NOT pre-quantized to Q8_K, because gemv_q5_k() takes a float* input directly.

Parameters
weight_q5_kQ5_K quantized weights [rows * cols]
input_f32FP32 input vector [cols]
outputFP32 output vector [rows]
rowsNumber of output rows
colsNumber of columns (must be multiple of 256)

Definition at line 301 of file ck_parity_api.c.

305 {
306  /*
307  * IMPORTANT: gemv_q5_k() expects raw FP32 activations, NOT pre-quantized Q8_K.
308  *
309  * This is different from gemv_q4_k_q8_k() and gemv_q5_0_q8_0() which are
310  * "quantized dot product" kernels that take block_q8_K or block_q8_0 input.
311  *
312  * WHY THIS IS ERROR-PRONE:
313  * When copying from ck_test_gemv_q5_0() (which calls gemv_q5_0_q8_0),
314  * it is natural to assume Q5_K also needs pre-quantization. But the
315  * function name tells you: gemv_q5_k() takes float*, while
316  * gemv_q5_0_q8_0() takes block_q8_0*. If the kernel name does not
317  * have "_q8_0" or "_q8_k" suffix, it expects FP32 input.
318  *
319  * PARITY NOTE:
320  * llama.cpp reference uses ggml_vec_dot_q5_K_q8_K which quantizes
321  * the input to Q8_K internally. Our FP32 path will have slightly
322  * different numerical results. Use tolerance ~1e-2 for comparison.
323  * To get exact parity, implement gemv_q5_k_q8_k() (quantized dot product).
324  */
325  for (int r = 0; r < rows; r++) {
326  gemv_q5_k(&output[r],
327  (const char *)weight_q5_k + r * (cols / CK_QK_K) * sizeof(block_q5_K),
328  input_f32, 1, cols);
329  }
330 }
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K)

References CK_QK_K, and gemv_q5_k().

◆ ck_test_gemv_q6_k()

void ck_test_gemv_q6_k ( const void *  weight_q6k,
const float *  input_f32,
float *  output,
int  cols 
)

Q6_K GEMV.

Definition at line 167 of file ck_parity_api.c.

171 {
172  /* Q6_K GEMV is not yet implemented in CK - provide reference impl */
173  /* For now, dequantize and compute in FP32 */
174  float *weight_f32 = (float *)malloc(cols * sizeof(float));
175  if (!weight_f32) {
176  *output = 0.0f;
177  return;
178  }
179 
180  dequant_q6_k_row(weight_q6k, weight_f32, cols);
181 
182  /* Dot product in FP32 */
183  double sum = 0.0;
184  for (int i = 0; i < cols; i++) {
185  sum += (double)weight_f32[i] * (double)input_f32[i];
186  }
187  *output = (float)sum;
188 
189  free(weight_f32);
190 }

References dequant_q6_k_row().

◆ ck_test_gemv_q8_0()

void ck_test_gemv_q8_0 ( const void *  weight_q8_0,
const float *  input_f32,
float *  output,
int  rows,
int  cols 
)

Q8_0 GEMV - matrix-vector multiply with Q8_0 weights.

Parameters
weight_q8_0Q8_0 quantized weights [rows * cols]
input_f32FP32 input vector [cols]
outputFP32 output vector [rows]
rowsNumber of output rows
colsNumber of columns (must be multiple of 32)

Definition at line 220 of file ck_parity_api.c.

224 {
225  /* Match llama.cpp's test_gemv_q8_0:
226  * 1. Quantize input to Q8_0 format
227  * 2. Use quantized dot product (vec_dot_q8_0_q8_0)
228  *
229  * This ensures parity with llama.cpp which always uses the
230  * quantized path, NOT the FP32 dequantization path.
231  */
232  int n_blocks = cols / CK_QK8_0;
233  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_blocks * sizeof(block_q8_0));
234  if (!q8_data) {
235  for (int r = 0; r < rows; r++) output[r] = 0.0f;
236  return;
237  }
238 
239  /* Quantize input to Q8_0 */
240  quantize_row_q8_0(input_f32, q8_data, cols);
241 
242  /* Call the quantized GEMV kernel (same as ck_test_gemv_q8_0_q8_0) */
243  gemv_q8_0_q8_0(output, weight_q8_0, q8_data, rows, cols);
244 
245  free(q8_data);
246 }
void gemv_q8_0_q8_0(float *y, const void *W, const void *x_q8, int M, int K)
Matrix-vector multiply with Q8_0 weights and Q8_0 input.

References CK_QK8_0, gemv_q8_0_q8_0(), and quantize_row_q8_0().

◆ ck_test_gemv_q8_0_q8_0()

void ck_test_gemv_q8_0_q8_0 ( const void *  weight_q8_0,
const float *  input_f32,
float *  output,
int  rows,
int  cols 
)

Q8_0 x Q8_0 quantized GEMV - matches llama.cpp's approach.

This version quantizes the input to Q8_0 first, then uses integer dot products (like llama.cpp does). Use this for parity testing.

Parameters
weight_q8_0Q8_0 quantized weights [rows * cols]
input_f32FP32 input vector [cols] - will be quantized to Q8_0
outputFP32 output vector [rows]
rowsNumber of output rows
colsNumber of columns (must be multiple of 32)

Definition at line 274 of file ck_parity_api.c.

278 {
279  /* This matches llama.cpp's approach:
280  * 1. Quantize input to Q8_0 format
281  * 2. Use quantized dot product (integer math)
282  * 3. Scale at the end
283  */
284  int n_blocks = cols / CK_QK8_0;
285  block_q8_0 *q8_data = (block_q8_0 *)malloc(n_blocks * sizeof(block_q8_0));
286  if (!q8_data) {
287  for (int r = 0; r < rows; r++) output[r] = 0.0f;
288  return;
289  }
290 
291  /* Quantize input to Q8_0 */
292  quantize_row_q8_0(input_f32, q8_data, cols);
293 
294  /* Call the quantized GEMV kernel */
295  gemv_q8_0_q8_0(output, weight_q8_0, q8_data, rows, cols);
296 
297  free(q8_data);
298 }

References CK_QK8_0, gemv_q8_0_q8_0(), and quantize_row_q8_0().

◆ ck_test_outproj_mlp_fused_q5_0()

void ck_test_outproj_mlp_fused_q5_0 ( const float *  attn_out,
const float *  residual,
const float *  ln2_gamma,
const void *  wo,
const void *  w1,
const void *  w2,
float *  output,
int  tokens,
int  num_heads,
int  head_dim,
int  embed_dim,
int  intermediate,
float  eps,
int  w2_is_q6k 
)

Test mega-fused OutProj + MLP kernel (Q5_0 weights)

This is a simplified wrapper for parity testing that:

  • Uses Q5_0 for W_o and W1 weights
  • Uses Q4_K for W2 weights
  • Allocates scratch internally
Parameters
attn_outAttention output [num_heads, tokens, head_dim] (FP32, head-major)
residualResidual input [tokens, embed_dim] (FP32)
ln2_gammaRMSNorm gamma [embed_dim] (FP32)
woOutProj weights [embed_dim, embed_dim] (Q5_0)
w1MLP W1 weights [2*intermediate, embed_dim] (Q5_0)
w2MLP W2 weights [embed_dim, intermediate] (Q4_K or Q6_K)
outputOutput [tokens, embed_dim] (FP32)
tokensNumber of tokens
num_headsNumber of attention heads
head_dimDimension per head
embed_dimEmbedding dimension (= num_heads * head_dim)
intermediateMLP intermediate dimension
epsRMSNorm epsilon
w2_is_q6kIf true, W2 is Q6_K; if false, W2 is Q4_K

Definition at line 894 of file ck_parity_api.c.

909 {
910  /* CK uses dtype enum: CK_DT_Q5_0 = 11, CK_DT_Q4_K = 7, CK_DT_Q6_K = 8 */
911  const int CK_DT_Q5_0_VAL = 11;
912  const int CK_DT_Q4_K_VAL = 7;
913  const int CK_DT_Q6_K_VAL = 8;
914 
915  /* For parity testing, aligned = actual (no padding) */
916  int aligned_embed_dim = embed_dim;
917  int aligned_head_dim = head_dim;
918  int aligned_intermediate = intermediate;
919 
920  /* Ensure intermediate is multiple of 256 (QK_K) for K-quants */
921  if ((intermediate % 256) != 0) {
922  aligned_intermediate = ((intermediate + 255) / 256) * 256;
923  }
924 
925  /* Allocate scratch */
926  size_t scratch_size = mega_fused_outproj_mlp_prefill_scratch_size(
927  tokens, aligned_embed_dim, num_heads, aligned_head_dim, aligned_intermediate);
928 
929  void *scratch = malloc(scratch_size);
930  if (!scratch) {
931  return;
932  }
933 
 934  /* Call the mega-fused kernel */
 935  mega_fused_outproj_mlp_prefill(
 936  output,
937  attn_out,
938  residual,
939  ln2_gamma,
940  wo, NULL, CK_DT_Q5_0_VAL, /* W_o with Q5_0 */
941  w1, NULL, CK_DT_Q5_0_VAL, /* W1 with Q5_0 */
942  w2, NULL, w2_is_q6k ? CK_DT_Q6_K_VAL : CK_DT_Q4_K_VAL, /* W2 with Q4_K or Q6_K */
943  tokens,
944  embed_dim,
945  aligned_embed_dim,
946  num_heads,
947  aligned_head_dim,
948  intermediate,
949  aligned_intermediate,
950  eps,
951  scratch
952  );
953 
954  free(scratch);
955 }
void mega_fused_outproj_mlp_prefill(float *output, const float *attn_out, const float *residual, const float *ln2_gamma, const void *wo, const float *bo, int wo_dt, const void *w1, const float *b1, int w1_dt, const void *w2, const float *b2, int w2_dt, int tokens, int embed_dim, int aligned_embed_dim, int num_heads, int aligned_head_dim, int intermediate_dim, int aligned_intermediate_dim, float eps, void *scratch)
size_t mega_fused_outproj_mlp_prefill_scratch_size(int tokens, int aligned_embed_dim, int num_heads, int aligned_head_dim, int aligned_intermediate_dim)
Get scratch buffer size for mega_fused_outproj_mlp_prefill.

References mega_fused_outproj_mlp_prefill(), and mega_fused_outproj_mlp_prefill_scratch_size().

◆ ck_test_quantize_q8_k()

void ck_test_quantize_q8_k ( const float *  src,
void *  dst,
int  n 
)

Quantize FP32 to Q8_K (for activations)

Parameters
srcInput FP32 values
dstOutput Q8_K blocks
nNumber of elements (must be multiple of 256)

Definition at line 136 of file ck_parity_api.c.

137 {
138  quantize_row_q8_k(src, dst, n);
139 }

References quantize_row_q8_k().

◆ ck_test_rmsnorm()

void ck_test_rmsnorm ( const float *  input,
const float *  weight,
float *  output,
int  n_tokens,
int  dim,
float  eps 
)

RMSNorm.

Computes: output = (input / rms(input)) * weight where rms(x) = sqrt(mean(x^2) + eps)

Parameters
inputInput tensor [n_tokens, dim]
weightNormalization weights [dim]
outputOutput tensor [n_tokens, dim]
n_tokensNumber of tokens
dimHidden dimension
epsEpsilon for numerical stability

Definition at line 557 of file ck_parity_api.c.

561 {
562  /* CK rmsnorm_forward has aligned_embed_dim parameter
563  * For testing, use dim as aligned_embed_dim (no padding) */
564  rmsnorm_forward(input, weight, output, NULL, n_tokens, dim, dim, eps);
565 }
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)

References rmsnorm_forward().

◆ ck_test_rope()

void ck_test_rope ( float *  q,
float *  k,
int  n_tokens,
int  n_heads,
int  n_heads_kv,
int  head_dim,
int  pos_offset,
float  theta 
)

RoPE (Rotary Position Embedding)

Applies rotary position embeddings to Q and K tensors.

NOTE: CK uses rotate-half format (split first/second halves) while some implementations use interleaved format. The test harness should account for this.

Parameters
qQuery tensor [n_tokens, n_heads * head_dim], modified in-place
kKey tensor [n_tokens, n_heads_kv * head_dim], modified in-place
n_tokensNumber of tokens
n_headsNumber of query heads
n_heads_kvNumber of key/value heads
head_dimDimension per head
pos_offsetStarting position for RoPE
thetaRoPE base frequency (typically 10000.0)

Definition at line 567 of file ck_parity_api.c.

570 {
571  /* Precompute cos/sin cache */
572  int half_dim = head_dim / 2;
573  int max_seq = pos_offset + n_tokens;
574 
575  float *cos_cache = (float *)malloc(max_seq * half_dim * sizeof(float));
576  float *sin_cache = (float *)malloc(max_seq * half_dim * sizeof(float));
577  if (!cos_cache || !sin_cache) {
578  free(cos_cache);
579  free(sin_cache);
580  return;
581  }
582 
583  rope_precompute_cache(cos_cache, sin_cache, max_seq, head_dim, theta);
584 
585  /* CK RoPE expects layout [num_heads, num_tokens, head_dim]
586  * Reshape from [n_tokens, n_heads * head_dim] to [n_heads, n_tokens, head_dim]
587  */
588  float *q_reorder = (float *)malloc(n_heads * n_tokens * head_dim * sizeof(float));
589  float *k_reorder = (float *)malloc(n_heads_kv * n_tokens * head_dim * sizeof(float));
590 
591  if (q_reorder && k_reorder) {
592  /* Reorder Q: [T, H*D] -> [H, T, D] */
593  for (int t = 0; t < n_tokens; t++) {
594  for (int h = 0; h < n_heads; h++) {
595  for (int d = 0; d < head_dim; d++) {
596  q_reorder[h * n_tokens * head_dim + t * head_dim + d] =
597  q[t * n_heads * head_dim + h * head_dim + d];
598  }
599  }
600  }
601 
602  /* Reorder K: [T, H_kv*D] -> [H_kv, T, D] */
603  for (int t = 0; t < n_tokens; t++) {
604  for (int h = 0; h < n_heads_kv; h++) {
605  for (int d = 0; d < head_dim; d++) {
606  k_reorder[h * n_tokens * head_dim + t * head_dim + d] =
607  k[t * n_heads_kv * head_dim + h * head_dim + d];
608  }
609  }
610  }
611 
612  /* Apply RoPE */
613  rope_forward_qk(q_reorder, k_reorder,
614  cos_cache, sin_cache,
615  n_heads, n_heads_kv, n_tokens,
616  head_dim, head_dim, pos_offset);
617 
618  /* Reorder back: [H, T, D] -> [T, H*D] */
619  for (int t = 0; t < n_tokens; t++) {
620  for (int h = 0; h < n_heads; h++) {
621  for (int d = 0; d < head_dim; d++) {
622  q[t * n_heads * head_dim + h * head_dim + d] =
623  q_reorder[h * n_tokens * head_dim + t * head_dim + d];
624  }
625  }
626  }
627 
628  for (int t = 0; t < n_tokens; t++) {
629  for (int h = 0; h < n_heads_kv; h++) {
630  for (int d = 0; d < head_dim; d++) {
631  k[t * n_heads_kv * head_dim + h * head_dim + d] =
632  k_reorder[h * n_tokens * head_dim + t * head_dim + d];
633  }
634  }
635  }
636  }
637 
638  free(q_reorder);
639  free(k_reorder);
640  free(cos_cache);
641  free(sin_cache);
642 }
void rope_precompute_cache(float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
Definition: rope_kernels.c:52
void rope_forward_qk(float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:448

References rope_forward_qk(), and rope_precompute_cache().

◆ ck_test_rope_interleaved()

void ck_test_rope_interleaved ( float *  q,
float *  k,
int  n_tokens,
int  n_heads,
int  n_heads_kv,
int  head_dim,
int  pos_offset,
float  theta 
)

RoPE with interleaved format (for llama.cpp compatibility)

Uses interleaved format: (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)

Definition at line 644 of file ck_parity_api.c.

647 {
648  /* Interleaved RoPE format (matches llama.cpp):
649  * (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)
650  * Applied to consecutive pairs of elements
651  */
652 
653  /* Precompute inverse frequencies */
654  float *inv_freq = (float *)malloc((head_dim / 2) * sizeof(float));
655  if (!inv_freq) return;
656 
657  for (int i = 0; i < head_dim / 2; i++) {
658  inv_freq[i] = 1.0f / powf(theta, (float)(2 * i) / head_dim);
659  }
660 
661  /* Apply RoPE to Q */
662  for (int t = 0; t < n_tokens; t++) {
663  int pos = pos_offset + t;
664  for (int h = 0; h < n_heads; h++) {
665  float *qh = q + t * n_heads * head_dim + h * head_dim;
666 
667  for (int i = 0; i < head_dim / 2; i++) {
668  float freq = pos * inv_freq[i];
669  float cos_val = cosf(freq);
670  float sin_val = sinf(freq);
671 
672  /* Interleaved format */
673  float x0 = qh[i * 2];
674  float x1 = qh[i * 2 + 1];
675  qh[i * 2] = x0 * cos_val - x1 * sin_val;
676  qh[i * 2 + 1] = x0 * sin_val + x1 * cos_val;
677  }
678  }
679  }
680 
681  /* Apply RoPE to K */
682  for (int t = 0; t < n_tokens; t++) {
683  int pos = pos_offset + t;
684  for (int h = 0; h < n_heads_kv; h++) {
685  float *kh = k + t * n_heads_kv * head_dim + h * head_dim;
686 
687  for (int i = 0; i < head_dim / 2; i++) {
688  float freq = pos * inv_freq[i];
689  float cos_val = cosf(freq);
690  float sin_val = sinf(freq);
691 
692  float x0 = kh[i * 2];
693  float x1 = kh[i * 2 + 1];
694  kh[i * 2] = x0 * cos_val - x1 * sin_val;
695  kh[i * 2 + 1] = x0 * sin_val + x1 * cos_val;
696  }
697  }
698  }
699 
700  free(inv_freq);
701 }

◆ ck_test_softmax()

void ck_test_softmax ( const float *  input,
float *  output,
int  n 
)

Softmax (simple, non-causal)

Computes: output[i] = exp(input[i]) / sum(exp(input))

Parameters
inputInput tensor [n]
outputOutput tensor [n]
nNumber of elements

Definition at line 710 of file ck_parity_api.c.

711 {
712  /* Find max for numerical stability */
713  float max_val = input[0];
714  for (int i = 1; i < n; i++) {
715  if (input[i] > max_val) max_val = input[i];
716  }
717 
718  /* Compute exp and sum */
719  float sum = 0.0f;
720  for (int i = 0; i < n; i++) {
721  output[i] = expf(input[i] - max_val);
722  sum += output[i];
723  }
724 
725  /* Normalize */
726  float inv_sum = 1.0f / sum;
727  for (int i = 0; i < n; i++) {
728  output[i] *= inv_sum;
729  }
730 }

◆ ck_test_swiglu()

void ck_test_swiglu ( const float *  gate_up,
float *  output,
int  n_tokens,
int  intermediate_dim 
)

SwiGLU activation.

Computes: output = SiLU(gate) * up where SiLU(x) = x * sigmoid(x)

Parameters
gate_upInput tensor [n_tokens, 2 * intermediate_dim] Layout: [gate_0..gate_D-1, up_0..up_D-1] per token
outputOutput tensor [n_tokens, intermediate_dim]
n_tokensNumber of tokens
intermediate_dimIntermediate dimension

Definition at line 703 of file ck_parity_api.c.

706 {
707  swiglu_forward(gate_up, output, n_tokens, intermediate_dim);
708 }
void swiglu_forward(const float *input, float *output, int tokens, int dim)

References swiglu_forward().

◆ ck_test_vec_dot_q5_0_q8_0()

void ck_test_vec_dot_q5_0_q8_0 ( const void *  weight_q5_0,
const void *  input_q8_0,
float *  output,
int  cols 
)

Direct Q5_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)

Direct Q5_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)

This is a "direct" test that bypasses FP32-to-Q8_0 conversion. Useful for isolating kernel bugs from quantization bugs.

Parameters
weight_q5_0Q5_0 quantized weights [cols]
input_q8_0Q8_0 quantized input [cols] (pre-quantized!)
outputOutput scalar [1]
colsNumber of elements (must be multiple of 32)

Definition at line 364 of file ck_parity_api.c.

368 {
369  vec_dot_q5_0_q8_0(cols, output, weight_q5_0, input_q8_0);
370 }
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.

References vec_dot_q5_0_q8_0().

◆ ck_test_vec_dot_q8_0_q8_0()

void ck_test_vec_dot_q8_0_q8_0 ( const void *  weight_q8_0,
const void *  input_q8_0,
float *  output,
int  cols 
)

Direct Q8_0 x Q8_0 dot product test (takes pre-quantized Q8_0 input)

Direct Q8_0 x Q8_0 dot product (takes pre-quantized Q8_0 input)

Parameters
weight_q8_0Q8_0 quantized weights [cols]
input_q8_0Q8_0 quantized input [cols] (pre-quantized!)
outputOutput scalar [1]
colsNumber of elements (must be multiple of 32)

Definition at line 380 of file ck_parity_api.c.

384 {
385  vec_dot_q8_0_q8_0(cols, output, weight_q8_0, input_q8_0);
386 }
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.

References vec_dot_q8_0_q8_0().

◆ dequant_q4_0_row()

void dequant_q4_0_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q4_0 row (multiple blocks)

Parameters
srcQ4_0 data
dstFP32 output
n_elementsNumber of elements to dequantize

Definition at line 61 of file dequant_kernels.c.

62 {
63  const block_q4_0 *blocks = (const block_q4_0 *)src;
64  const size_t n_blocks = n_elements / QK4_0;
65 
66  for (size_t b = 0; b < n_blocks; b++) {
67  dequant_q4_0_block(&blocks[b], &dst[b * QK4_0]);
68  }
69 }
#define QK4_0
Definition: ckernel_quant.h:35
void dequant_q4_0_block(const block_q4_0 *block, float *output)
Dequantize a single Q4_0 block to FP32.

Referenced by ck_test_dequant_q4_0(), and dequant_row().

◆ dequant_q4_k_row()

void dequant_q4_k_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q4_K row (multiple blocks)

Definition at line 370 of file dequant_kernels.c.

371 {
372  const block_q4_K *blocks = (const block_q4_K *)src;
373  const size_t n_blocks = n_elements / QK_K;
374 
375  for (size_t b = 0; b < n_blocks; b++) {
376  dequant_q4_k_block(&blocks[b], &dst[b * QK_K]);
377  }
378 }
void dequant_q4_k_block(const block_q4_K *block, float *output)
Dequantize a single Q4_K block to FP32.

Referenced by ck_test_dequant_q4_k(), and dequant_row().

◆ dequant_q5_1_row()

void dequant_q5_1_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q5_1 row (multiple blocks)

Definition at line 255 of file dequant_kernels.c.

256 {
257  const block_q5_1 *blocks = (const block_q5_1 *)src;
258  const size_t n_blocks = n_elements / QK5_1;
259 
260  for (size_t b = 0; b < n_blocks; b++) {
261  dequant_q5_1_block(&blocks[b], &dst[b * QK5_1]);
262  }
263 }
void dequant_q5_1_block(const block_q5_1 *block, float *output)
Dequantize a single Q5_1 block to FP32.

Referenced by ck_test_dequant_q5_1(), and dequant_row().

◆ dequant_q6_k_row()

void dequant_q6_k_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q6_K row (multiple blocks)

Definition at line 420 of file dequant_kernels.c.

421 {
422  const block_q6_K *blocks = (const block_q6_K *)src;
423  const size_t n_blocks = n_elements / QK_K;
424 
425  for (size_t b = 0; b < n_blocks; b++) {
426  dequant_q6_k_block(&blocks[b], &dst[b * QK_K]);
427  }
428 }
void dequant_q6_k_block(const block_q6_K *block, float *output)
Dequantize a single Q6_K block to FP32.

Referenced by ck_test_dequant_q6_k(), ck_test_gemv_q6_k(), and dequant_row().

◆ geglu_backward_fp32()

void geglu_backward_fp32 ( const float *  x,
const float *  d_out,
float *  d_x,
int  tokens,
int  dim 
)

GeGLU backward pass (fp32)

Test:
test_geglu.py::TestGeGLU::test_geglu_backward_fp32

dL/dx given dL/d(out) where out = GELU(a) * b Chain rule: dL/da = dL/dout * d(GELU)/da * b dL/db = dL/dout * GELU(a)

After changes: make test

Definition at line 843 of file gelu_kernels.c.

848 {
849  const float sqrt_2_over_pi = 0.7978845608f;
850  const float coeff = 0.044715f;
851 
852  const int inner_dim = dim * 2;
853 
854  for (int t = 0; t < tokens; ++t) {
855  const float *x_ptr = x + (size_t)t * inner_dim;
856  const float *d_out_ptr = d_out + (size_t)t * dim;
857  float *d_x_ptr = d_x + (size_t)t * inner_dim;
858 
859  for (int d = 0; d < dim; ++d) {
860  float a = x_ptr[d];
861  float b = x_ptr[dim + d];
862  float dout = d_out_ptr[d];
863 
864  // GELU(a) derivative components
865  float a2 = a * a;
866  float a3 = a2 * a;
867  float g = sqrt_2_over_pi * (a + coeff * a3);
868  float tanh_g = tanhf(g);
869  float sech2_g = 1.0f - tanh_g * tanh_g;
870  float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * a2);
871 
872  // d(GELU)/da = 0.5 * (1 + tanh(g)) + 0.5 * a * sech^2(g) * g'
873  float d_gelu = 0.5f * (1.0f + tanh_g) + 0.5f * a * sech2_g * g_prime;
874 
875  // dL/da = dL/dout * d(GELU)/da * b
876  d_x_ptr[d] = dout * d_gelu * b;
877 
878  // dL/db = dL/dout * GELU(a)
879  float gelu_a = 0.5f * a * (1.0f + tanh_g);
880  d_x_ptr[dim + d] = dout * gelu_a;
881  }
882  }
883 }

Referenced by ck_test_geglu_backward().

◆ geglu_forward_fp32()

void geglu_forward_fp32 ( const float *  x,
float *  out,
int  tokens,
int  dim 
)

GeGLU forward pass (fp32)

Test:
test_geglu.py::TestGeGLU::test_geglu_forward_fp32

Computes out = GELU(a) * b where x = [a, b] along last dimension. Input shape: [tokens, 2 * dim], Output shape: [tokens, dim]

After changes: make test

Definition at line 623 of file gelu_kernels.c.

624 {
625  const float sqrt_2_over_pi = 0.7978845608f;
626  const float coeff = 0.044715f;
627 
628  const int inner_dim = dim * 2;
629 
630 #if defined(__AVX512F__)
631  const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
632  const __m512 coeff_vec = _mm512_set1_ps(coeff);
633  const __m512 half_vec = _mm512_set1_ps(0.5f);
634  const __m512 one_vec = _mm512_set1_ps(1.0f);
635 
636  for (int t = 0; t < tokens; ++t) {
637  const float *x_ptr = x + (size_t)t * inner_dim;
638  float *out_ptr = out + (size_t)t * dim;
639 
640  int d = 0;
641  // Process first half (a) with GELU, second half (b) directly
642  for (; d + 32 <= dim; d += 32) {
643  // Load a (first half of inner_dim)
644  __m512 a0 = _mm512_loadu_ps(&x_ptr[d]);
645  __m512 a1 = _mm512_loadu_ps(&x_ptr[d + 16]);
646 
647  // Compute GELU(a)
648  __m512 a0_sq = _mm512_mul_ps(a0, a0);
649  __m512 a0_cu = _mm512_mul_ps(a0_sq, a0);
650  __m512 a1_sq = _mm512_mul_ps(a1, a1);
651  __m512 a1_cu = _mm512_mul_ps(a1_sq, a1);
652 
653  // inner = sqrt(2/pi) * (a + 0.044715 * a^3)
654  __m512 inner0 = _mm512_fmadd_ps(coeff_vec, a0_cu, a0);
655  __m512 inner1 = _mm512_fmadd_ps(coeff_vec, a1_cu, a1);
656  inner0 = _mm512_mul_ps(sqrt_2_pi_vec, inner0);
657  inner1 = _mm512_mul_ps(sqrt_2_pi_vec, inner1);
658 
659  // tanh(inner)
660  __m512 tanh0 = tanh512_fast(inner0);
661  __m512 tanh1 = tanh512_fast(inner1);
662 
663  // GELU = 0.5 * a * (1 + tanh)
664  __m512 gelu0 = _mm512_mul_ps(half_vec, _mm512_mul_ps(a0, _mm512_add_ps(one_vec, tanh0)));
665  __m512 gelu1 = _mm512_mul_ps(half_vec, _mm512_mul_ps(a1, _mm512_add_ps(one_vec, tanh1)));
666 
667  // Load b (second half of inner_dim)
668  __m512 b0 = _mm512_loadu_ps(&x_ptr[dim + d]);
669  __m512 b1 = _mm512_loadu_ps(&x_ptr[dim + d + 16]);
670 
671  // out = GELU(a) * b
672  _mm512_storeu_ps(&out_ptr[d], _mm512_mul_ps(gelu0, b0));
673  _mm512_storeu_ps(&out_ptr[d + 16], _mm512_mul_ps(gelu1, b1));
674  }
675  // Handle remaining
676  for (; d < dim; ++d) {
677  float a = x_ptr[d];
678  float b = x_ptr[dim + d];
679  float a3 = a * a * a;
680  float inner = sqrt_2_over_pi * (a + coeff * a3);
681  float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
682  out_ptr[d] = gelu_a * b;
683  }
684  }
685 
686 #elif defined(__AVX2__)
687  const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
688  const __m256 coeff_vec = _mm256_set1_ps(coeff);
689  const __m256 half_vec = _mm256_set1_ps(0.5f);
690  const __m256 one_vec = _mm256_set1_ps(1.0f);
691 
692  for (int t = 0; t < tokens; ++t) {
693  const float *x_ptr = x + (size_t)t * inner_dim;
694  float *out_ptr = out + (size_t)t * dim;
695 
696  int d = 0;
697  for (; d + 16 <= dim; d += 16) {
698  // Load a
699  __m256 a0 = _mm256_loadu_ps(&x_ptr[d]);
700  __m256 a1 = _mm256_loadu_ps(&x_ptr[d + 8]);
701 
702  // GELU(a)
703  __m256 a0_sq = _mm256_mul_ps(a0, a0);
704  __m256 a0_cu = _mm256_mul_ps(a0_sq, a0);
705  __m256 a1_sq = _mm256_mul_ps(a1, a1);
706  __m256 a1_cu = _mm256_mul_ps(a1_sq, a1);
707 
708  __m256 inner0 = _mm256_fmadd_ps(coeff_vec, a0_cu, a0);
709  __m256 inner1 = _mm256_fmadd_ps(coeff_vec, a1_cu, a1);
710  inner0 = _mm256_mul_ps(sqrt_2_pi_vec, inner0);
711  inner1 = _mm256_mul_ps(sqrt_2_pi_vec, inner1);
712 
713  __m256 tanh0 = tanh256_fast(inner0);
714  __m256 tanh1 = tanh256_fast(inner1);
715 
716  __m256 gelu0 = _mm256_mul_ps(half_vec, _mm256_mul_ps(a0, _mm256_add_ps(one_vec, tanh0)));
717  __m256 gelu1 = _mm256_mul_ps(half_vec, _mm256_mul_ps(a1, _mm256_add_ps(one_vec, tanh1)));
718 
719  // b
720  __m256 b0 = _mm256_loadu_ps(&x_ptr[dim + d]);
721  __m256 b1 = _mm256_loadu_ps(&x_ptr[dim + d + 8]);
722 
723  _mm256_storeu_ps(&out_ptr[d], _mm256_mul_ps(gelu0, b0));
724  _mm256_storeu_ps(&out_ptr[d + 8], _mm256_mul_ps(gelu1, b1));
725  }
726  for (; d < dim; ++d) {
727  float a = x_ptr[d];
728  float b = x_ptr[dim + d];
729  float a3 = a * a * a;
730  float inner = sqrt_2_over_pi * (a + coeff * a3);
731  float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
732  out_ptr[d] = gelu_a * b;
733  }
734  }
735 
736 #elif defined(__AVX__)
737  const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
738  const __m256 coeff_vec = _mm256_set1_ps(coeff);
739  const __m256 half_vec = _mm256_set1_ps(0.5f);
740  const __m256 one_vec = _mm256_set1_ps(1.0f);
741 
742  float inner_arr[8] __attribute__((aligned(32)));
743  float tanh_arr[8] __attribute__((aligned(32)));
744 
745  for (int t = 0; t < tokens; ++t) {
746  const float *x_ptr = x + (size_t)t * inner_dim;
747  float *out_ptr = out + (size_t)t * dim;
748 
749  int d = 0;
750  for (; d + 8 <= dim; d += 8) {
751  __m256 a = _mm256_loadu_ps(&x_ptr[d]);
752  __m256 a_sq = _mm256_mul_ps(a, a);
753  __m256 a_cu = _mm256_mul_ps(a_sq, a);
754 
755  __m256 coeff_a_cu = _mm256_mul_ps(coeff_vec, a_cu);
756  __m256 inner = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(a, coeff_a_cu));
757 
758  _mm256_store_ps(inner_arr, inner);
759  for (int j = 0; j < 8; ++j) {
760  tanh_arr[j] = tanhf(inner_arr[j]);
761  }
762  __m256 tanh_val = _mm256_load_ps(tanh_arr);
763 
764  __m256 gelu = _mm256_mul_ps(half_vec, _mm256_mul_ps(a, _mm256_add_ps(one_vec, tanh_val)));
765  __m256 b = _mm256_loadu_ps(&x_ptr[dim + d]);
766 
767  _mm256_storeu_ps(&out_ptr[d], _mm256_mul_ps(gelu, b));
768  }
769  for (; d < dim; ++d) {
770  float a = x_ptr[d];
771  float b = x_ptr[dim + d];
772  float a3 = a * a * a;
773  float inner = sqrt_2_over_pi * (a + coeff * a3);
774  float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
775  out_ptr[d] = gelu_a * b;
776  }
777  }
778 
779 #else
780  // Scalar fallback
781  for (int t = 0; t < tokens; ++t) {
782  const float *x_ptr = x + (size_t)t * inner_dim;
783  float *out_ptr = out + (size_t)t * dim;
784 
785  for (int d = 0; d < dim; ++d) {
786  float a = x_ptr[d];
787  float b = x_ptr[dim + d];
788  float a3 = a * a * a;
789  float inner = sqrt_2_over_pi * (a + coeff * a3);
790  float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
791  out_ptr[d] = gelu_a * b;
792  }
793  }
794 #endif
795 }
__attribute__((visibility("default"))) CKTokenizer *ck_tokenizer_create(CKTokenizerType type)

Referenced by ck_test_geglu(), and geglu_forward_bf16().

◆ gemm_nt_q4_k_q8_k()

/**
 * NT GEMM: C = A @ B^T (+ bias) with Q8_K activations and Q4_K weights.
 *
 * Thin wrapper: forwards to the core gemm_q4_k_q8_k kernel with the output
 * and batch dimensions swapped, then broadcasts the optional bias over rows.
 *
 * @param A_q8 Activations in Q8_K format [M x K]
 * @param B    Weights in Q4_K format [N x K]
 * @param bias Optional bias [N], NULL if unused
 * @param C    Output matrix [M x N], row-major fp32
 */
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias,
                       float *C, int M, int N, int K)
{
    /* Reject NULL buffers and degenerate shapes up front. */
    if (A_q8 == NULL || B == NULL || C == NULL) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    gemm_q4_k_q8_k(C, B, A_q8, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias == NULL) {
        return;
    }
    for (int m = 0; m < M; ++m) {
        float *c_row = C + (size_t)m * (size_t)N;
        for (int j = 0; j < N; ++j) {
            c_row[j] += bias[j];
        }
    }
}
void gemm_q4_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
#define C(color)
Definition: show_config.c:39

Referenced by ck_test_gemm_q4_k().

◆ gemm_nt_q5_0_q8_0()

void gemm_nt_q5_0_q8_0 ( const void *  A_q8,
const void *  B_q5,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.

Computes C = A @ B^T + bias where: A: [M x K] Q8_0 quantized activations (M tokens, K features) B: [N x K] Q5_0 quantized weights (N outputs, K features) C: [M x N] FP32 output

This is the INT8 batch kernel for prefill, using pre-quantized activations to avoid FP32->Q8_0 conversion overhead per operation.

Parameters
A_q8Input activations in Q8_0 format [M rows of K/32 blocks each]
B_q5Weights in Q5_0 format [N rows of K/32 blocks each]
biasOptional bias vector [N], NULL if not used
COutput matrix [M x N], row-major FP32
MBatch size (number of tokens)
NOutput dimension (number of output features)
KInput dimension (must be multiple of 32)

Definition at line 1617 of file gemm_kernels_q5_0.c.

1625 {
1626  const block_q5_0 *weights = (const block_q5_0 *)B_q5;
1627  const block_q8_0 *inputs = (const block_q8_0 *)A_q8;
1628  const int blocks_per_row = K / QK5_0;
1629 
1630  for (int m = 0; m < M; m++) {
1631  const block_q8_0 *input_row = &inputs[m * blocks_per_row];
1632 
1633  for (int n = 0; n < N; n++) {
1634  const block_q5_0 *weight_row = &weights[n * blocks_per_row];
1635  float *out = &C[m * N + n];
1636 
1637  /* Dispatches to vec_dot_q5_0_q8_0_avx (2x block unrolled) on AVX */
1638  vec_dot_q5_0_q8_0(K, out, weight_row, input_row);
1639 
1640  if (bias) {
1641  *out += bias[n];
1642  }
1643  }
1644  }
1645 }
#define QK5_0
Definition: ckernel_quant.h:67
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.

Referenced by ck_test_gemm_q5_0().

◆ gemm_nt_q5_1()

void gemm_nt_q5_1 ( const float *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

GEMM with transposed Q5_1 weights: C = A @ B^T.

Parameters
AInput activations [M x K], row-major FP32
BWeight matrix in Q5_1 format [N x K], row-major quantized
biasOptional bias [N], NULL if not used
COutput [M x N], row-major FP32
MBatch size (number of tokens)
NOutput dimension
KInput dimension

Definition at line 309 of file gemm_kernels_q5_1.c.

314 {
315  const block_q5_1 *blocks = (const block_q5_1 *)B;
316  const int blocks_per_row = K / QK5_1;
317 
318  for (int m = 0; m < M; m++) {
319  const float *a_row = &A[m * K];
320 
321  for (int n = 0; n < N; n++) {
322  float sum = 0.0f;
323 
324  for (int b = 0; b < blocks_per_row; b++) {
325  const block_q5_1 *block = &blocks[n * blocks_per_row + b];
326  const float d = CK_FP16_TO_FP32(block->d);
327  const float min = CK_FP16_TO_FP32(block->m);
328  const float *ap = &a_row[b * QK5_1];
329 
330  uint32_t qh;
331  memcpy(&qh, block->qh, sizeof(qh));
332 
333  /* First 16 weights: low nibbles, high bits from qh[0:15] */
334  for (int j = 0; j < QK5_1 / 2; j++) {
335  const int lo = (block->qs[j] & 0x0F);
336  const int hi = ((qh >> j) & 1) << 4;
337  sum += (d * (float)(lo | hi) + min) * ap[j];
338  }
339 
340  /* Second 16 weights: high nibbles, high bits from qh[16:31] */
341  for (int j = 0; j < QK5_1 / 2; j++) {
342  const int lo = (block->qs[j] >> 4);
343  const int hi = ((qh >> (j + 16)) & 1) << 4;
344  sum += (d * (float)(lo | hi) + min) * ap[j + QK5_1 / 2];
345  }
346  }
347 
348  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
349  }
350  }
351 }
#define CK_FP16_TO_FP32(x)
uint8_t qs[32/2]
Definition: ckernel_quant.h:90
uint8_t qh[4]
Definition: ckernel_quant.h:89
ck_half m
Definition: ckernel_quant.h:88
ck_half d
Definition: ckernel_quant.h:87

Referenced by ck_test_gemm_q5_1().

◆ gemm_nt_q5_k()

/**
 * GEMM with transposed Q5_K weights: C = A @ B^T (+ bias).
 *
 * Every SIMD tier currently falls through to the scalar reference kernel,
 * so the per-ISA preprocessor ladder is collapsed into a single call.
 */
void gemm_nt_q5_k(const float *A, const void *B, const float *bias,
                  float *C, int M, int N, int K)
{
    /* TODO: add AVX-512 / AVX2 / AVX / SSE4.1 implementations and restore
     * compile-time feature dispatch once they exist. */
    gemm_nt_q5_k_ref(A, B, bias, C, M, N, K);
}
void gemm_nt_q5_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)

Referenced by ck_test_gemm_q5_k().

◆ gemm_nt_q6_k_q8_k()

/**
 * NT GEMM: C = A @ B^T (+ bias) with Q8_K activations and Q6_K weights.
 *
 * Typical inference pattern:
 *   A_q8: activations in Q8_K format [M x K]
 *   B:    weights in Q6_K format [N x K]
 *   C:    output [M x N]
 * Delegates to gemm_q6_k_q8_k with output/batch dims swapped, then
 * broadcasts the optional bias across rows.
 *
 * @param A_q8 Input activations in Q8_K format
 * @param B    Weight matrix in Q6_K format
 * @param bias Optional bias vector [N]
 * @param C    Output matrix
 * @param M    Batch size (number of tokens)
 * @param N    Output dimension
 * @param K    Input dimension
 */
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias,
                       float *C, int M, int N, int K)
{
    /* Reject NULL buffers and degenerate shapes up front. */
    if (A_q8 == NULL || B == NULL || C == NULL) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    gemm_q6_k_q8_k(C, B, A_q8, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias == NULL) {
        return;
    }
    for (int m = 0; m < M; ++m) {
        float *c_row = C + (size_t)m * (size_t)N;
        for (int j = 0; j < N; ++j) {
            c_row[j] += bias[j];
        }
    }
}
void gemm_q6_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.

Referenced by ck_test_gemm_q6_k().

◆ gemm_nt_q8_0_q8_0()

/**
 * gemm_nt_q8_0_q8_0 with optional bias (matches header signature).
 *
 * C[m,n] = A[m,K] @ B[n,K]^T + bias[n]. The GEMM core is chosen at compile
 * time from the best available ISA; bias is added afterwards in fp32.
 */
void gemm_nt_q8_0_q8_0(const void *A, const void *B, const float *bias,
                       float *C, int M, int N, int K)
{
    /* GEMM core, best available ISA first */
#if defined(__AVX512VNNI__)
    gemm_nt_q8_0_q8_0_vnni(A, B, C, M, N, K);
#elif defined(__AVX512F__)
    gemm_nt_q8_0_q8_0_avx512(A, B, C, M, N, K);
#elif defined(__AVX2__)
    gemm_nt_q8_0_q8_0_avx2(A, B, C, M, N, K);
#elif defined(__AVX__)
    gemm_nt_q8_0_q8_0_avx(A, B, C, M, N, K);
#else
    gemm_nt_q8_0_q8_0_ref(A, B, C, M, N, K);
#endif

    /* Optional bias, broadcast over every output row */
    if (bias == NULL) {
        return;
    }
    for (int m = 0; m < M; m++) {
        for (int j = 0; j < N; j++) {
            C[(size_t)m * N + j] += bias[j];
        }
    }
}
void gemm_nt_q8_0_q8_0_ref(const void *A, const void *B, float *C, int M, int N, int K)
Scalar reference: gemm_nt_q8_0_q8_0.

Referenced by ck_test_gemm_q8_0().

◆ gemv_q4_k_q8_k()

/**
 * GEMV: y = W @ x with Q4_K weights and Q8_K input.
 *
 * Compile-time dispatch to the best available ISA variant; all variants
 * share the same signature and semantics.
 */
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
{
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
    /* VNNI INT8 dot-product acceleration: best for single-token decode */
    gemv_q4_k_q8_k_vnni(y, W, x_q8, M, K);
#elif defined(__AVX2__)
    gemv_q4_k_q8_k_avx2(y, W, x_q8, M, K);
#elif defined(__AVX__)
    /* AVX variant uses maddubs_epi16 (more efficient than the SSE path) */
    gemv_q4_k_q8_k_avx(y, W, x_q8, M, K);
#elif defined(__SSE4_1__)
    gemv_q4_k_q8_k_sse(y, W, x_q8, M, K);
#else
    gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
#endif
}
void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)

Referenced by ck_test_gemv_q4_k().

◆ gemv_q5_0()

/**
 * Auto-dispatch GEMV for Q5_0 weights based on compile-time CPU features.
 *
 * Priority: AVX-512 > AVX2 > AVX > SSE4.1 > scalar reference.
 *
 * @param y Output vector [M]
 * @param W Weight matrix in Q5_0 format [M x K]
 * @param x Input vector [K]
 * @param M Number of output rows
 * @param K Number of input columns (hidden dimension)
 */
void gemv_q5_0(float *y, const void *W, const float *x, int M, int K)
{
#if defined(__AVX512F__)
    gemv_q5_0_avx512(y, W, x, M, K);
#elif defined(__AVX2__)
    gemv_q5_0_avx2(y, W, x, M, K);
#elif defined(__AVX__)
    gemv_q5_0_avx(y, W, x, M, K);
#elif defined(__SSE4_1__)
    gemv_q5_0_sse_v2(y, W, x, M, K);
#else
    gemv_q5_0_ref(y, W, x, M, K);
#endif
}
void gemv_q5_0_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q5_0 weights (scalar reference)

Referenced by dot_q5_0(), gemm_nt_q5_0(), and gemm_q5_0().

◆ gemv_q5_0_q8_0()

void gemv_q5_0_q8_0 ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Matrix-vector multiply with Q5_0 weights and Q8_0 input.

Parameters
yOutput vector [M]
WWeight matrix in Q5_0 format [M x K]
x_q8Input vector in Q8_0 format [K]
MNumber of output rows
KNumber of columns (must be multiple of 32)

Definition at line 1529 of file gemm_kernels_q5_0.c.

1533 {
1534  const block_q5_0 *w_blocks = (const block_q5_0 *)W;
1535  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1536  const int blocks_per_row = K / QK5_0;
1537 
1538  for (int row = 0; row < M; row++) {
1539  vec_dot_q5_0_q8_0(K, &y[row],
1540  &w_blocks[row * blocks_per_row],
1541  x_blocks);
1542  }
1543 }

Referenced by ck_test_gemv_q5_0(), and ck_test_gemv_q5_0_q8_0().

◆ gemv_q5_1()

/**
 * Auto-dispatch GEMV for Q5_1 weights.
 *
 * Only an AVX-512 variant exists beyond the scalar reference.
 */
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K)
{
#ifdef __AVX512F__
    gemv_q5_1_avx512(y, W, x, M, K);
#else
    gemv_q5_1_ref(y, W, x, M, K);
#endif
}
void gemv_q5_1_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q5_1 weights (scalar reference)

Referenced by ck_test_gemv_q5_1(), dot_q5_1(), and gemm_q5_1().

◆ gemv_q5_k()

/**
 * GEMV with Q5_K weights: y = W @ x.
 *
 * Every SIMD tier currently falls through to the scalar reference kernel,
 * so the per-ISA preprocessor ladder is collapsed into a single call.
 */
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K)
{
    /* TODO: add AVX-512 / AVX2 / AVX / SSE4.1 implementations and restore
     * compile-time feature dispatch once they exist. */
    gemv_q5_k_ref(y, W, x, M, K);
}
void gemv_q5_k_ref(float *y, const void *W, const float *x, int M, int K)

Referenced by ck_test_gemv_q5_k().

◆ gemv_q6_k_q8_k()

/**
 * GEMV: y = W @ x where W is Q6_K and x is Q8_K.
 *
 * Compile-time ISA dispatch. The AVX-512 variant uses the same algorithm
 * as AVX2 (matches llama.cpp).
 */
void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
{
#if defined(__AVX512F__) && defined(__AVX512BW__)
    gemv_q6_k_q8_k_avx512(y, W, x_q8, M, K);
#elif defined(__AVX2__)
    gemv_q6_k_q8_k_avx2(y, W, x_q8, M, K);
#elif defined(__AVX__)
    gemv_q6_k_q8_k_avx(y, W, x_q8, M, K);
#elif defined(__SSSE3__)
    gemv_q6_k_q8_k_sse(y, W, x_q8, M, K);
#else
    gemv_q6_k_q8_k_ref(y, W, x_q8, M, K);
#endif
}
void gemv_q6_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx512(float *y, const void *W, const void *x_q8, int M, int K)

Referenced by gemm_q6_k_q8_k().

◆ gemv_q8_0()

/**
 * Auto-dispatch GEMV for Q8_0 weights based on compile-time CPU features.
 *
 * Priority: AVX-512 > AVX2 > AVX > SSE4.1 > scalar reference.
 *
 * @param y Output vector [M]
 * @param W Weight matrix in Q8_0 format [M x K]
 * @param x Input vector [K]
 * @param M Number of output rows
 * @param K Number of input columns (hidden dimension)
 */
void gemv_q8_0(float *y, const void *W, const float *x, int M, int K)
{
#if defined(__AVX512F__)
    gemv_q8_0_avx512(y, W, x, M, K);
#elif defined(__AVX2__)
    gemv_q8_0_avx2(y, W, x, M, K);
#elif defined(__AVX__)
    gemv_q8_0_avx(y, W, x, M, K);
#elif defined(__SSE4_1__)
    gemv_q8_0_sse(y, W, x, M, K);
#else
    gemv_q8_0_ref(y, W, x, M, K);
#endif
}
void gemv_q8_0_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q8_0 weights (scalar reference)

Referenced by dot_q8_0(), gemm_nt_q8_0(), and gemm_q8_0().

◆ gemv_q8_0_q8_0()

void gemv_q8_0_q8_0 ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Matrix-vector multiply with Q8_0 weights and Q8_0 input.

Parameters
yOutput vector [M]
WWeight matrix in Q8_0 format [M x K]
x_q8Input vector in Q8_0 format [K]
MNumber of output rows
KNumber of columns (must be multiple of 32)

Definition at line 1042 of file gemm_kernels_q8_0.c.

1046 {
1047  const block_q8_0 *w_blocks = (const block_q8_0 *)W;
1048  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1049  const int blocks_per_row = K / QK8_0;
1050 
1051  for (int row = 0; row < M; row++) {
1052  vec_dot_q8_0_q8_0(K, &y[row],
1053  &w_blocks[row * blocks_per_row],
1054  x_blocks);
1055  }
1056 }
#define QK8_0
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.

Referenced by ck_test_gemv_q8_0(), and ck_test_gemv_q8_0_q8_0().

◆ mega_fused_outproj_mlp_prefill()

void mega_fused_outproj_mlp_prefill ( float *  output,
const float *  attn_out,
const float *  residual,
const float *  ln2_gamma,
const void *  wo,
const float *  bo,
int  wo_dt,
const void *  w1,
const float *  b1,
int  w1_dt,
const void *  w2,
const float *  b2,
int  w2_dt,
int  tokens,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim,
int  intermediate_dim,
int  aligned_intermediate_dim,
float  eps,
void *  scratch 
)

◆ mega_fused_outproj_mlp_prefill_scratch_size()

size_t mega_fused_outproj_mlp_prefill_scratch_size ( int  tokens,
int  aligned_embed_dim,
int  num_heads,
int  aligned_head_dim,
int  aligned_intermediate_dim 
)

Get scratch buffer size for mega_fused_outproj_mlp_prefill.

Definition at line 159 of file mega_fused_outproj_mlp_prefill.c.

164 {
165  if (tokens <= 0 || aligned_embed_dim <= 0 || num_heads <= 0 ||
166  aligned_head_dim <= 0 || aligned_intermediate_dim <= 0) {
167  return 0;
168  }
169 
170  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0,
171  (size_t)aligned_head_dim);
172  const size_t attn_q8_bytes = (size_t)num_heads * (size_t)tokens * q8_row_bytes;
173  const size_t h1_bytes = (size_t)tokens * (size_t)aligned_embed_dim * sizeof(float);
174  const size_t ln2_bytes = h1_bytes;
175  const size_t mlp_scratch = fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(
176  aligned_embed_dim, aligned_intermediate_dim);
177 
178  return align_up_size(attn_q8_bytes, 64) +
179  align_up_size(h1_bytes, 64) +
180  align_up_size(ln2_bytes, 64) +
181  align_up_size(mlp_scratch, 64);
182 }
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim, int aligned_intermediate_dim)
Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.
static size_t align_up_size(size_t value, size_t align)

Referenced by ck_test_outproj_mlp_fused_q5_0().

◆ quantize_row_q8_0()

void quantize_row_q8_0 ( const float *  x,
void *  vy,
int  k 
)

Quantize FP32 to Q8_0 format (scalar reference)

Parameters
xInput FP32 values
vyOutput Q8_0 blocks
kNumber of elements (must be multiple of 32)

Definition at line 59 of file gemm_kernels_q8_0.c.

60 {
61  block_q8_0 *y = (block_q8_0 *)vy;
62  const int nb = k / QK8_0; /* QK8_0 = 32 */
63 
64 #if defined(__AVX__)
65  const __m256 sign_bit = _mm256_set1_ps(-0.0f);
66  const __m256 v_half = _mm256_set1_ps(0.5f);
67  const __m256 v_min = _mm256_set1_ps(-127.0f);
68  const __m256 v_max = _mm256_set1_ps(127.0f);
69 
70  for (int i = 0; i < nb; i++) {
71  __m256 v0 = _mm256_loadu_ps(x + 0);
72  __m256 v1 = _mm256_loadu_ps(x + 8);
73  __m256 v2 = _mm256_loadu_ps(x + 16);
74  __m256 v3 = _mm256_loadu_ps(x + 24);
75  x += QK8_0;
76 
77  __m256 max_abs = _mm256_andnot_ps(sign_bit, v0);
78  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v1));
79  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v2));
80  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v3));
81 
82  __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max_abs, 1),
83  _mm256_castps256_ps128(max_abs));
84  max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
85  max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
86  const float max_scalar = _mm_cvtss_f32(max4);
87 
88  const float d = max_scalar / 127.0f;
89  const float id = max_scalar != 0.0f ? 127.0f / max_scalar : 0.0f;
90  y[i].d = CK_FP32_TO_FP16(d);
91 
92  const __m256 mul = _mm256_set1_ps(id);
93  v0 = _mm256_mul_ps(v0, mul);
94  v1 = _mm256_mul_ps(v1, mul);
95  v2 = _mm256_mul_ps(v2, mul);
96  v3 = _mm256_mul_ps(v3, mul);
97 
98  v0 = _mm256_min_ps(_mm256_max_ps(v0, v_min), v_max);
99  v1 = _mm256_min_ps(_mm256_max_ps(v1, v_min), v_max);
100  v2 = _mm256_min_ps(_mm256_max_ps(v2, v_min), v_max);
101  v3 = _mm256_min_ps(_mm256_max_ps(v3, v_min), v_max);
102 
103  /* Round half away from zero to match the scalar path */
104  v0 = _mm256_add_ps(v0, _mm256_or_ps(_mm256_and_ps(v0, sign_bit), v_half));
105  v1 = _mm256_add_ps(v1, _mm256_or_ps(_mm256_and_ps(v1, sign_bit), v_half));
106  v2 = _mm256_add_ps(v2, _mm256_or_ps(_mm256_and_ps(v2, sign_bit), v_half));
107  v3 = _mm256_add_ps(v3, _mm256_or_ps(_mm256_and_ps(v3, sign_bit), v_half));
108 
109  __m256i i0 = _mm256_cvttps_epi32(v0);
110  __m256i i1 = _mm256_cvttps_epi32(v1);
111  __m256i i2 = _mm256_cvttps_epi32(v2);
112  __m256i i3 = _mm256_cvttps_epi32(v3);
113 
114 #if defined(__AVX2__)
115  i0 = _mm256_packs_epi32(i0, i1);
116  i2 = _mm256_packs_epi32(i2, i3);
117  i0 = _mm256_packs_epi16(i0, i2);
118 
119  const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
120  i0 = _mm256_permutevar8x32_epi32(i0, perm);
121  _mm256_storeu_si256((__m256i *)y[i].qs, i0);
122 #else
123  __m128i ni0 = _mm256_castsi256_si128(i0);
124  __m128i ni1 = _mm256_extractf128_si256(i0, 1);
125  __m128i ni2 = _mm256_castsi256_si128(i1);
126  __m128i ni3 = _mm256_extractf128_si256(i1, 1);
127  __m128i ni4 = _mm256_castsi256_si128(i2);
128  __m128i ni5 = _mm256_extractf128_si256(i2, 1);
129  __m128i ni6 = _mm256_castsi256_si128(i3);
130  __m128i ni7 = _mm256_extractf128_si256(i3, 1);
131 
132  ni0 = _mm_packs_epi32(ni0, ni1);
133  ni2 = _mm_packs_epi32(ni2, ni3);
134  ni4 = _mm_packs_epi32(ni4, ni5);
135  ni6 = _mm_packs_epi32(ni6, ni7);
136 
137  ni0 = _mm_packs_epi16(ni0, ni2);
138  ni4 = _mm_packs_epi16(ni4, ni6);
139 
140  _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
141  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
142 #endif
143  }
144 #else
145  for (int i = 0; i < nb; i++) {
146  const float *xb = x + i * QK8_0;
147 
148  /* Find max absolute value in block */
149  float amax = 0.0f;
150  for (int j = 0; j < QK8_0; j++) {
151  float av = xb[j] >= 0 ? xb[j] : -xb[j];
152  if (av > amax) amax = av;
153  }
154 
155  /* Compute scale: d = max / 127 */
156  float d = amax / 127.0f;
157  float id = d != 0.0f ? 127.0f / amax : 0.0f;
158 
159  /* Store scale as FP16 */
160  y[i].d = CK_FP32_TO_FP16(d);
161 
162  /* Quantize values */
163  for (int j = 0; j < QK8_0; j++) {
164  float v = xb[j] * id;
165  /* Round to nearest int and clamp to [-127, 127] */
166  int q = (int)(v + (v >= 0 ? 0.5f : -0.5f));
167  if (q > 127) q = 127;
168  if (q < -127) q = -127;
169  y[i].qs[j] = (int8_t)q;
170  }
171  }
172 #endif
173 }
#define CK_FP32_TO_FP16(x)
int8_t qs[32]
int32_t id
Definition: tokenizer.h:315

Referenced by ck_test_gemm_q5_0(), ck_test_gemm_q8_0(), ck_test_gemv_q5_0(), ck_test_gemv_q5_0_q8_0(), ck_test_gemv_q8_0(), and ck_test_gemv_q8_0_q8_0().

◆ quantize_row_q8_k()

/**
 * Quantize a row of FP32 values into Q8_K super-blocks.
 *
 * Uses the SSE4.1 variant when available, scalar reference otherwise.
 */
void quantize_row_q8_k(const float *x, void *vy, int k)
{
#if defined(__SSE4_1__)
    quantize_row_q8_k_sse(x, vy, k);
#else
    quantize_row_q8_k_ref(x, vy, k);
#endif
}
void quantize_row_q8_k_sse(const float *x, void *vy, int k)
void quantize_row_q8_k_ref(const float *x, void *vy, int k)

Referenced by ck_test_gemm_q4_k(), ck_test_gemm_q6_k(), ck_test_gemv_q4_k(), and ck_test_quantize_q8_k().

◆ rmsnorm_forward()

void rmsnorm_forward ( const float *  input,
const float *  gamma,
float *  output,
float *  rstd_cache,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float  eps 
)

RMSNorm forward pass

Test:

test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens

test_rmsnorm.py::TestRMSNormForward::test_fp32_single

test_rmsnorm.py::TestRMSNormForward::test_perf_rolled

test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat

test_parity.py::test_rmsnorm_parity

RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)

After changes: make test && make llamacpp-parity-full

Definition at line 50 of file rmsnorm_kernels.c.

58 {
59  int T = tokens;
60  int D = d_model;
61  int aligned = aligned_embed_dim;
62 
63  for (int t = 0; t < T; ++t) {
64  const float *x = input + (size_t)t * aligned;
65  float *y = output + (size_t)t * aligned;
66 
67 #if defined(__AVX512F__)
68  // AVX-512: Process 16 floats at a time
69  __m512 sum_sq_vec = _mm512_setzero_ps();
70  int d = 0;
71 
72  // Vectorized sum of squares
73  for (; d + 16 <= D; d += 16) {
74  __m512 xv = _mm512_loadu_ps(&x[d]);
75  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
76  }
77  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
78 
79  // Handle remaining elements
80  for (; d < D; ++d) {
81  sum_sq += x[d] * x[d];
82  }
83 
84  float mean_sq = sum_sq / (float)D;
85  float rstd = 1.0f / sqrtf(mean_sq + eps);
86  if (rstd_cache) {
87  rstd_cache[t] = rstd;
88  }
89 
90  // Apply normalization and scale (vectorized)
91  __m512 rstd_vec = _mm512_set1_ps(rstd);
92  d = 0;
93  for (; d + 16 <= D; d += 16) {
94  __m512 xv = _mm512_loadu_ps(&x[d]);
95  __m512 gv = _mm512_loadu_ps(&gamma[d]);
96  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
97  __m512 yv = _mm512_mul_ps(x_hat, gv);
98  _mm512_storeu_ps(&y[d], yv);
99  }
100  // Handle remaining elements
101  for (; d < D; ++d) {
102  y[d] = x[d] * rstd * gamma[d];
103  }
104 
105 #elif defined(__AVX__)
106  // AVX: Process 8 floats at a time
107  __m256 sum_sq_vec = _mm256_setzero_ps();
108  int d = 0;
109 
110  // Vectorized sum of squares (no FMA in AVX1, use mul + add)
111  for (; d + 8 <= D; d += 8) {
112  __m256 xv = _mm256_loadu_ps(&x[d]);
113  __m256 xv_sq = _mm256_mul_ps(xv, xv);
114  sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
115  }
116  float sum_sq = hsum256_ps_rmsnorm(sum_sq_vec);
117 
118  // Handle remaining elements
119  for (; d < D; ++d) {
120  sum_sq += x[d] * x[d];
121  }
122 
123  float mean_sq = sum_sq / (float)D;
124  float rstd = 1.0f / sqrtf(mean_sq + eps);
125  if (rstd_cache) {
126  rstd_cache[t] = rstd;
127  }
128 
129  // Apply normalization and scale (vectorized)
130  __m256 rstd_vec = _mm256_set1_ps(rstd);
131  d = 0;
132  for (; d + 8 <= D; d += 8) {
133  __m256 xv = _mm256_loadu_ps(&x[d]);
134  __m256 gv = _mm256_loadu_ps(&gamma[d]);
135  __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
136  __m256 yv = _mm256_mul_ps(x_hat, gv);
137  _mm256_storeu_ps(&y[d], yv);
138  }
139  // Handle remaining elements
140  for (; d < D; ++d) {
141  y[d] = x[d] * rstd * gamma[d];
142  }
143 
144 #else
145  // Scalar fallback
146  double sum_sq = 0.0;
147  for (int d = 0; d < D; ++d) {
148  double v = (double)x[d];
149  sum_sq += v * v;
150  }
151  double mean_sq = sum_sq / (double)D;
152  double r = sqrt(mean_sq + (double)eps);
153  float rstd = (float)(1.0 / r);
154  if (rstd_cache) {
155  rstd_cache[t] = rstd;
156  }
157 
158  // Apply normalization and scale
159  for (int d = 0; d < D; ++d) {
160  float x_hat = x[d] * rstd;
161  y[d] = x_hat * gamma[d];
162  }
163 #endif
164 
165  // Zero padding (if any)
166  for (int d = D; d < aligned; ++d) {
167  y[d] = 0.0f;
168  }
169  }
170 }

Referenced by ck_test_rmsnorm().

◆ rope_forward_qk()

void rope_forward_qk ( float *  q,
float *  k,
const float *  cos_cache,
const float *  sin_cache,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  pos_offset 
)

RoPE forward for both Q and K (common inference pattern)

Test:

test_rope.py::TestRoPEForward::test_rope_forward_qk

test_fused_attention_decode.py::TestFusedAttentionDecode::test_qk_rope

test_parity.py::test_rope_qk_parity

Combined RoPE forward for both Q and K in one call. Layouts: q is [num_heads, num_tokens, head_dim]; k is [num_kv_heads, num_tokens, head_dim].

After changes: make test && make llamacpp-parity-full

Definition at line 448 of file rope_kernels.c.

/**
 * Apply RoPE to both Q and K tensors in a single call.
 *
 * q: [num_heads, num_tokens, head_dim] (head-major)
 * k: [num_kv_heads, num_tokens, head_dim]
 *
 * Thin convenience wrapper: delegates to rope_forward() once per tensor;
 * the only difference between the two calls is the head count (GQA may
 * have fewer KV heads than query heads).
 */
void rope_forward_qk(float *q, float *k,
                     const float *cos_cache, const float *sin_cache,
                     int num_heads, int num_kv_heads,
                     int num_tokens, int head_dim,
                     int aligned_head_dim, int pos_offset)
{
    /* Rotate the query heads. */
    rope_forward(q, cos_cache, sin_cache, num_heads, num_tokens,
                 head_dim, aligned_head_dim, pos_offset);
    /* Rotate the key heads with the same cache and position offset. */
    rope_forward(k, cos_cache, sin_cache, num_kv_heads, num_tokens,
                 head_dim, aligned_head_dim, pos_offset);
}
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:180

Referenced by ck_test_rope().

◆ rope_precompute_cache()

void rope_precompute_cache ( float *  cos_cache,
float *  sin_cache,
int  max_seq_len,
int  head_dim,
float  base 
)

Precompute RoPE cos/sin cache

Test:

test_rope.py::TestRoPECache::test_cache_computation

test_rope.py::TestRoPECache::test_cache_values

Precomputes cos(m * theta_i) and sin(m * theta_i) for positions 0..max_seq_len-1. cos_cache, sin_cache: [max_seq_len, head_dim/2]

After changes: make test

Definition at line 52 of file rope_kernels.c.

57 {
58  int half_dim = head_dim / 2;
59 
60  long double base_ld = (long double)base;
61  long double head_dim_ld = (long double)head_dim;
62  long double log_base = logl(base_ld);
63  for (int pos = 0; pos < max_seq_len; ++pos) {
64  for (int i = 0; i < half_dim; ++i) {
65  long double exponent = ((long double)(2 * i)) / head_dim_ld;
66  long double freq = expl(-exponent * log_base);
67  float freq_f = (float)freq;
68  float angle_f = (float)pos * freq_f;
69  cos_cache[pos * half_dim + i] = cosf(angle_f);
70  sin_cache[pos * half_dim + i] = sinf(angle_f);
71  }
72  }
73 }

Referenced by ck_test_rope().

◆ swiglu_forward()

void swiglu_forward ( const float *  input,
float *  output,
int  tokens,
int  dim 
)

SwiGLU forward pass

Test:

test_swiglu.py::TestSwiGLUForward::test_forward_tokens

test_swiglu.py::TestSwiGLUForward::test_forward_single

test_mlp.py::TestMLPForward::test_swiglu_mlp

test_fused_swiglu_decode.py::TestFusedSwiGLUDecode::test_fused_swiglu_decode

test_parity.py::test_swiglu_parity

SwiGLU: y = silu(gate) * up where silu(x) = x * sigmoid(x)

After changes: make test && make llamacpp-parity-full

Definition at line 131 of file swiglu_kernels.c.

/**
 * SwiGLU activation: out[t][d] = silu(gate) * up, silu(x) = x * sigmoid(x).
 *
 * input:  [tokens, 2*dim] — the first dim entries of each row are the gate
 *         projection, the next dim entries are the up/value projection.
 * output: [tokens, dim]
 *
 * Vector width is chosen at compile time (AVX-512 / AVX2 / AVX1); the
 * trailing scalar loop finishes whatever the vector loop left over and is
 * the complete implementation on non-SIMD builds.
 */
void swiglu_forward(const float *input, float *output, int tokens, int dim)
{
    for (int t = 0; t < tokens; ++t) {
        const float *in_row = input + (size_t)t * (2 * dim);
        float *dst = output + (size_t)t * dim;
        int col = 0;

#if defined(__AVX512F__)
        /* 16-wide lanes with the vectorized fast sigmoid. */
        for (; col + 16 <= dim; col += 16) {
            __m512 gate = _mm512_loadu_ps(&in_row[col]);
            __m512 up   = _mm512_loadu_ps(&in_row[dim + col]);
            __m512 sig  = sigmoid512_fast(gate);
            /* (gate * sig) * up, same association as the scalar path */
            __m512 res  = _mm512_mul_ps(_mm512_mul_ps(gate, sig), up);
            _mm512_storeu_ps(&dst[col], res);
        }
#elif defined(__AVX2__)
        /* 8-wide lanes with the vectorized fast sigmoid. */
        for (; col + 8 <= dim; col += 8) {
            __m256 gate = _mm256_loadu_ps(&in_row[col]);
            __m256 up   = _mm256_loadu_ps(&in_row[dim + col]);
            __m256 sig  = sigmoid256_fast(gate);
            __m256 res  = _mm256_mul_ps(_mm256_mul_ps(gate, sig), up);
            _mm256_storeu_ps(&dst[col], res);
        }
#elif defined(__AVX__)
        /* AVX1 path: sigmoid is evaluated lane-by-lane through an aligned
         * staging buffer while the loads and multiplies stay vectorized. */
        float gate_lanes[8] __attribute__((aligned(32)));
        float sig_lanes[8] __attribute__((aligned(32)));

        for (; col + 8 <= dim; col += 8) {
            __m256 gate = _mm256_loadu_ps(&in_row[col]);
            __m256 up   = _mm256_loadu_ps(&in_row[dim + col]);

            _mm256_store_ps(gate_lanes, gate);
            for (int lane = 0; lane < 8; ++lane) {
                sig_lanes[lane] = sigmoid_scalar(gate_lanes[lane]);
            }
            __m256 sig = _mm256_load_ps(sig_lanes);

            __m256 res = _mm256_mul_ps(_mm256_mul_ps(gate, sig), up);
            _mm256_storeu_ps(&dst[col], res);
        }
#endif

        /* Scalar tail (and the whole row on non-SIMD builds). */
        for (; col < dim; ++col) {
            float gate = in_row[col];
            float up   = in_row[dim + col];
            dst[col] = gate * sigmoid_scalar(gate) * up;
        }
    }
}
float sigmoid_scalar(float x)
static void silu(float *x, int n)
Definition: v6_simple.c:159

Referenced by ck_test_swiglu().

◆ vec_dot_q5_0_q8_0()

void vec_dot_q5_0_q8_0 ( int  n,
float *  s,
const void *  vx,
const void *  vy 
)

Auto-dispatch quantized dot product Q5_0 x Q8_0.

Dispatch priority:

  1. AVX512 (best performance on modern Intel/AMD)
  2. AVX (256-bit float ops, works on Sandy/Ivy Bridge and newer)
  3. SSSE3 (128-bit fallback)
  4. Reference scalar (last resort)

Definition at line 1498 of file gemm_kernels_q5_0.c.

/**
 * Q5_0 x Q8_0 dot product, dispatched to the best ISA selected at
 * compile time: AVX-512 > AVX > SSSE3 > scalar reference.
 *
 * n:  element count
 * s:  out-param receiving the scalar dot-product result
 * vx: Q5_0-quantized operand blocks
 * vy: Q8_0-quantized operand blocks
 */
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
{
#ifdef __AVX512F__
    vec_dot_q5_0_q8_0_avx512(n, s, vx, vy);
#elif defined(__AVX__)
    /* 256-bit float path (Sandy/Ivy Bridge and newer). */
    vec_dot_q5_0_q8_0_avx(n, s, vx, vy);
#elif defined(__SSSE3__)
    /* 128-bit path for older CPUs. */
    vec_dot_q5_0_q8_0_sse(n, s, vx, vy);
#else
    /* Portable scalar reference — last resort. */
    vec_dot_q5_0_q8_0_ref(n, s, vx, vy);
#endif
}
void vec_dot_q5_0_q8_0_ref(int n, float *s, const void *vx, const void *vy)
Quantized dot product: Q5_0 weights x Q8_0 input (scalar reference)

Referenced by ck_test_vec_dot_q5_0_q8_0().

◆ vec_dot_q8_0_q8_0()

void vec_dot_q8_0_q8_0 ( int  n,
float *  s,
const void *  vx,
const void *  vy 
)

Auto-dispatch quantized dot product Q8_0 x Q8_0.

Definition at line 1013 of file gemm_kernels_q8_0.c.

/**
 * Q8_0 x Q8_0 dot product, dispatched to the best ISA selected at
 * compile time: AVX-512 > AVX > SSE4.1 > scalar reference.
 *
 * n:  element count
 * s:  out-param receiving the scalar dot-product result
 * vx: Q8_0-quantized operand blocks
 * vy: Q8_0-quantized operand blocks
 */
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
{
#if defined(__AVX512F__)
    vec_dot_q8_0_q8_0_avx512(n, s, vx, vy);
#elif defined(__AVX__)
    vec_dot_q8_0_q8_0_avx(n, s, vx, vy);
#elif defined(__SSE4_1__)
    vec_dot_q8_0_q8_0_sse(n, s, vx, vy);
#else
    /* Portable scalar reference — last resort. */
    vec_dot_q8_0_q8_0_ref(n, s, vx, vy);
#endif
}
void vec_dot_q8_0_q8_0_ref(int n, float *s, const void *vx, const void *vy)
Quantized dot product: Q8_0 weights x Q8_0 input (scalar reference)

Referenced by ck_test_vec_dot_q8_0_q8_0().