← Back to C-Kernel-Engine Docs Doxygen Source Documentation
ckernel_engine.h File Reference
#include <stddef.h>
#include <stdint.h>
#include "cpu_features.h"
#include "ckernel_quant.h"
#include "mega_fused_attention.h"

Go to the source code of this file.

Data Structures

struct  CKMathBackend
 

Functions

void add_backward_bf16 (const uint16_t *d_y, uint16_t *d_a, uint16_t *d_b, size_t n)
 
void add_forward_2d_bf16 (const uint16_t *a, const uint16_t *b, uint16_t *y, int tokens, int dim, int aligned_dim)
 
void add_forward_bf16 (const uint16_t *a, const uint16_t *b, uint16_t *y, size_t n)
 
void add_forward_f32 (const float *a, const float *b, float *y, size_t n)
 
void add_inplace_bf16 (uint16_t *a, const uint16_t *b, size_t n)
 
void add_inplace_f32 (float *a, const float *b, size_t n)
 
void add_scaled_forward_bf16 (const uint16_t *a, const uint16_t *b, uint16_t *y, float alpha, size_t n)
 
void add_scaled_inplace_bf16 (uint16_t *a, const uint16_t *b, float alpha, size_t n)
 
int argmax_f32 (const float *scores, int n)
 Find index of maximum value. More...
 
void attention_backward_causal_head_major (const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_backward_causal_head_major_gqa (const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_backward_causal_head_major_gqa_bf16 (const uint16_t *d_output, float *d_x, const uint16_t *q, const uint16_t *k, const uint16_t *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_d_output, float *scratch_q, float *scratch_k, float *scratch_v)
 
void attention_flash_decode (float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
 Main flash attention function with SIMD dispatch. More...
 
void attention_forward_causal_head_major (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_forward_causal_head_major_exact (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_forward_causal_head_major_gqa (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_forward_causal_head_major_gqa_bf16 (const uint16_t *q, const uint16_t *k, const uint16_t *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window, float *scratch_q, float *scratch_k, float *scratch_v)
 
void attention_forward_causal_head_major_gqa_exact (const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)
 
void attention_forward_causal_head_major_gqa_flash (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim)
 
void attention_forward_causal_head_major_gqa_flash_strided (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens)
 
void attention_forward_causal_head_major_gqa_flash_strided_sliding (const float *q, const float *k, const float *v, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int kv_stride_tokens, int sliding_window)
 
void attention_forward_decode_head_major_gqa_flash (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
 
void attention_forward_decode_head_major_gqa_flash_sliding (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim, int sliding_window)
 
void attention_forward_decode_head_major_gqa_regular (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
 WARNING: This is NOT true flash attention! More...
 
void axpy_2d_f32 (float *Y, const float *X, float alpha, int num_tokens, int dim, int y_stride, int x_stride)
 Batched AXPY for 2D tensors: Y[t,:] += alpha * X[t,:]. More...
 
void axpy_f32 (float *y, const float *x, float alpha, int n)
 In-place AXPY: y += alpha * x. More...
 
void axpy_zero_f32 (float *y, const float *x, float alpha, int n)
 Zero output then accumulate: y = 0; y += alpha * x. More...
 
void backward_causal_softmax_head_major (float *d_scores, const float *weights, int num_heads, int num_tokens, int aligned_context_window)
 
void backward_causal_softmax_head_major_bf16 (uint16_t *d_scores, const uint16_t *weights, int num_heads, int num_tokens, int aligned_context_window, float *scratch_d_scores, float *scratch_weights)
 
void causal_softmax_head_major (float *scores, int num_heads, int num_tokens, int aligned_context_window)
 
void causal_softmax_head_major_bf16 (uint16_t *scores, int num_heads, int num_tokens, int aligned_context_window, float *scratch)
 
void causal_softmax_head_major_exact (float *scores, int num_heads, int num_tokens, int aligned_context_window)
 
void ck_attention_flash_decode_wrapper (const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
 Wrapper to call TRUE flash attention from orchestration layer. More...
 
int ck_flash_attn_choose_tile_k (int D_h)
 
int ck_flash_attn_fast_exp_kind (void)
 
void ck_gemm_nt_head_major_q5_0 (const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
 Output projection from head-major attention (auto-dispatch) More...
 
void ck_gemm_nt_head_major_q8_0 (const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
 Output projection from head-major attention (Q8_0 weights) More...
 
int ck_get_num_threads (void)
 
int ck_get_physical_cores (void)
 
void ck_set_num_threads (int num_threads)
 
void ck_set_strict_parity (int enabled)
 
int ck_strict_parity_enabled (void)
 
CKMathBackend ckernel_backend_native (void)
 
void dequant_q4_0_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q4_0 row (multiple blocks) More...
 
void dequant_q4_1_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q4_1 row (multiple blocks) More...
 
void dequant_q4_k_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q4_K row (multiple blocks) More...
 
void dequant_q5_0_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q5_0 row (multiple blocks) More...
 
void dequant_q5_1_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q5_1 row (multiple blocks) More...
 
void dequant_q6_k_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q6_K row (multiple blocks) More...
 
void dequant_q8_0_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q8_0 row (multiple blocks) More...
 
void embedding_backward (const int32_t *token_ids, int token_count, const float *d_output, float *d_token_embeddings, float *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_backward_bf16 (const int32_t *token_ids, int token_count, const uint16_t *d_output, uint16_t *d_token_embeddings, uint16_t *d_pos_embeddings, int vocab_size, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward (const int32_t *token_ids, int token_count, int vocab_size, const float *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward_bf16 (const int32_t *token_ids, int token_count, int vocab_size, const uint16_t *token_embeddings, const uint16_t *pos_embeddings, uint16_t *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward_q4_k (const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward_q6_k (const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void embedding_forward_q8_0 (const int32_t *token_ids, int token_count, int vocab_size, const void *token_embeddings, const float *pos_embeddings, float *output, int embed_dim, int aligned_embed_dim, int context_window, int add_pos)
 
void fc1_backward_kernel (const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
 
void fc2_backward_kernel (const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
 
void fused_mlp_swiglu_decode (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
 
void fused_mlp_swiglu_decode_tiled (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
 
void fused_mlp_swiglu_decode_v2 (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
 
void fused_mlp_swiglu_prefill (const float *x, const float *W_gate, const float *W_up, const float *W_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
 Fused MLP (Gate + Up + SwiGLU + Down) for prefill. More...
 
void fused_mlp_swiglu_prefill_bias (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *B_gate, const float *B_up, const float *B_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
 Fused MLP (Gate + Up + SwiGLU + Down) for prefill with biases. More...
 
void fused_mlp_swiglu_prefill_w1w2_quant (const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch)
 Quantized fused MLP for prefill (W1=gate+up, W2=down) More...
 
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size (int aligned_embed_dim, int aligned_intermediate_dim)
 Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant. More...
 
size_t fused_mlp_swiglu_scratch_size (int intermediate)
 Get scratch buffer size for fused_mlp_swiglu_prefill. More...
 
void fused_rmsnorm_qkv_prefill (const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps, float *scratch)
 Fused RMSNorm + QKV projection for prefill. More...
 
void fused_rmsnorm_qkv_prefill_head_major (const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
 Fused RMSNorm + QKV projection for prefill (head-major outputs) More...
 
void fused_rmsnorm_qkv_prefill_head_major_quant (const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
 Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations) More...
 
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size (int aligned_embed_dim)
 Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant. More...
 
size_t fused_rmsnorm_qkv_scratch_size (int hidden)
 Get scratch buffer size for fused_rmsnorm_qkv_prefill. More...
 
void geglu_backward_fp32 (const float *x, const float *d_out, float *d_x, int tokens, int dim)
 
void geglu_forward_bf16 (const uint16_t *x, uint16_t *out, int tokens, int dim, float *scratch)
 
void geglu_forward_fp32 (const float *x, float *out, int tokens, int dim)
 
void gelu_backward_exact (const float *input, const float *d_output, float *d_input, size_t n)
 
void gelu_backward_exact_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
 
void gelu_backward_fast (const float *input, const float *d_output, float *d_input, size_t n)
 
void gelu_backward_fast_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
 
void gelu_backward_scalar (const float *input, const float *d_output, float *d_input, size_t n)
 
void gelu_exact_inplace (float *data, size_t n)
 
void gelu_fast_inplace (float *data, size_t n)
 
void gelu_fast_inplace_bf16 (uint16_t *data, size_t n, float *scratch)
 
void gemm_avx512_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_bias_gelu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_bias_relu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_bias_silu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_blocked_serial (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_blocked_serial_bf16 (const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K)
 
void gemm_fine_grained_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_microkernel (const float *A, const float *B, float *C, int M, int N, int K, int B_transposed)
 
void gemm_microkernel_blocked (const float *A, const float *B, float *C, int M, int N, int K)
 
void gemm_microkernel_blocked_bt (const float *A, const float *B, float *C, int M, int N, int K)
 
void gemm_microkernel_packed (const float *A, const float *B, float *C, int M, int N, int K)
 
void gemm_naive_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nn_avx512 (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nn_blocked (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nn_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nt_q4_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias. More...
 
void gemm_nt_q4_1 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 GEMM with transposed Q4_1 weights: C = A @ B^T. More...
 
void gemm_nt_q4_k (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nt_q4_k_q8_k (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nt_q5_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nt_q5_1 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 GEMM with transposed Q5_1 weights: C = A @ B^T. More...
 
void gemm_nt_q5_k (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nt_q6_k (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_nt_q6_k_q8_k (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
 NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K. More...
 
void gemm_nt_q8_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
 Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias. More...
 
void gemm_nt_q8_0_q8_0 (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
 gemm_nt_q8_0_q8_0 with optional bias (matches header signature) More...
 
void gemm_q4_k (float *Y, const void *W, const float *X, int M, int N, int K)
 Auto-dispatch GEMM based on available SIMD. More...
 
void gemm_q4_k_q8_k (float *Y, const void *W, const void *X_q8, int M, int N, int K)
 
void gemm_q6_k (float *Y, const void *W, const float *X, int M, int N, int K)
 
void gemm_q6_k_q8_k (float *Y, const void *W, const void *X_q8, int M, int N, int K)
 GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K. More...
 
void gemm_swiglu_fused (const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)
 
void gemm_tn_avx512 (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_tn_blocked (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_tn_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemv_fused_q5_0_bias_dispatch (float *y, const void *W, const float *x, const float *bias, int M, int K)
 
void gemv_fused_q8_0_bias_dispatch (float *y, const void *W, const float *x, const float *bias, int M, int K)
 
void gemv_q4_0 (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV. More...
 
void gemv_q4_k (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV based on available SIMD. More...
 
void gemv_q4_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K)
 
void gemv_q4_k_q8_k_parallel (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
 
void gemv_q4_k_q8_k_parallel_simd (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
 
void gemv_q4_k_q8_k_ref (float *y, const void *W, const void *x_q8, int M, int K)
 
void gemv_q5_0 (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV for Q5_0 weights based on CPU features. More...
 
void gemv_q5_0_parallel (float *y, const void *W, const float *x, int M, int K, int ith, int nth)
 Parallel reference GEMV for Q5_0 × FP32. More...
 
void gemv_q5_0_parallel_simd (float *y, const void *W, const float *x, int M, int K, int ith, int nth)
 Parallel SIMD GEMV for Q5_0 × FP32 with prefetching. More...
 
void gemv_q5_0_q8_0 (float *y, const void *W, const void *x_q8, int M, int K)
 Matrix-vector multiply with Q5_0 weights and Q8_0 input. More...
 
void gemv_q5_1 (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV. More...
 
void gemv_q5_k (float *y, const void *W, const float *x, int M, int K)
 
void gemv_q6_k (float *y, const void *W, const float *x, int M, int K)
 
void gemv_q6_k_q8_k (float *y, const void *W, const void *x_q8, int M, int K)
 GEMV: y = W @ x where W is Q6_K and x is Q8_K. More...
 
void gemv_q6_k_q8_k_parallel (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
 Parallel reference GEMV for Q6_K × Q8_K. More...
 
void gemv_q6_k_q8_k_parallel_simd (float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)
 Parallel SIMD GEMV for Q6_K × Q8_K. More...
 
void gemv_q8_0 (float *y, const void *W, const float *x, int M, int K)
 Auto-dispatch GEMV for Q8_0 weights based on CPU features. More...
 
void gemv_q8_0_q8_0 (float *y, const void *W, const void *x_q8, int M, int K)
 Matrix-vector multiply with Q8_0 weights and Q8_0 input. More...
 
void im2patch (const float *image, float *patches, int C, int H, int W, int P)
 
void im2patch_bf16 (const uint16_t *image, uint16_t *patches, int C, int H, int W, int P)
 
void kv_cache_repack_head_major_inplace (float *buf, int num_heads, int tokens, int cache_capacity, int aligned_head_dim)
 
void kv_cache_store (float *__restrict kv_cache_k, float *__restrict kv_cache_v, const float *__restrict k, const float *__restrict v, int layer, int pos, int num_kv_heads, int head_dim, int max_seq_len)
 
void kv_cache_write_head_major (const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)
 
void layernorm_backward_kernel (const float *d_output, const float *input, const float *gamma, const float *mean, const float *rstd, float *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim)
 
void layernorm_backward_kernel_bf16 (const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *mean, const float *rstd, uint16_t *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
 
void layernorm_forward_rolled_slice (const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps)
 
void layernorm_forward_rolled_slice_bf16 (const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
 
void layernorm_forward_unrolled_slice (const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)
 
void layernorm_forward_unrolled_slice_bf16 (const uint16_t *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, uint16_t *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps, float *scratch_input, float *scratch_output)
 
void layernorm_naive_serial (const float *input, const float *gamma, const float *beta, float *output, float *mean_cache, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
 
void layernorm_naive_serial_matched_precision (const float *input, const float *gamma, const float *beta, float *output, float *mean_cache, float *rstd_cache, int tokens, int d_model, float eps)
 
void mlp_token_parallel (const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
 
void mlp_token_parallel_bf16 (const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
 
void mlp_token_parallel_bf16_fp32act (const uint16_t *input, const uint16_t *W_fc1, const uint16_t *b_fc1, const uint16_t *W_fc2, const uint16_t *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads, float *scratch_input_f, float *scratch_bias1_f, float *scratch_bias2_f, uint16_t *scratch_fc1_bf16)
 
void mlp_token_parallel_exact (const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
 
void moe_accumulate_expert_f32 (float *output, const float *expert_output, float routing_weight, int hidden_dim)
 Accumulate expert output: output += routing_weight * expert_output. More...
 
void patch2im (const float *d_patches, float *d_image, int C, int H, int W, int P)
 
void patch2im_bf16 (const uint16_t *d_patches, uint16_t *d_image, int C, int H, int W, int P)
 
void quantize_batch_q8_0 (const float *x, void *y, int num_rows, int k)
 Batch quantize FP32 to Q8_0 format (row-major output) More...
 
void quantize_batch_q8_k (const float *x, void *y, int num_rows, int k)
 Batch quantize FP32 to Q8_K format (row-major output) More...
 
void quantize_row_q8_0 (const float *x, void *y, int k)
 Quantize FP32 to Q8_0 format (scalar reference) More...
 
void quantize_row_q8_k (const float *x, void *y, int k)
 
void relu_backward (const float *input, const float *d_output, float *d_input, size_t n)
 
void relu_backward_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n)
 
void relu_forward (const float *input, float *output, size_t n)
 
void relu_forward_bf16 (const uint16_t *input, uint16_t *output, size_t n)
 
void relu_forward_inplace (float *data, size_t n)
 
void relu_forward_inplace_bf16 (uint16_t *data, size_t n)
 
void rmsnorm_backward (const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
 
void rmsnorm_backward_bf16 (const uint16_t *d_output, const uint16_t *input, const float *gamma, const float *rstd_cache, uint16_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
 
void rmsnorm_backward_int4 (const uint8_t *d_output, const uint8_t *input, const float *gamma, const float *rstd_cache, uint8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
 
void rmsnorm_backward_int8 (const int8_t *d_output, const int8_t *input, const float *gamma, const float *rstd_cache, int8_t *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim, float *scratch_d_output, float *scratch_input, float *scratch_d_input)
 
void rmsnorm_forward (const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
 
void rmsnorm_forward_bf16 (const uint16_t *input, const float *gamma, uint16_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)
 
void rmsnorm_forward_int4 (const uint8_t *input, const float *gamma, uint8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
 
void rmsnorm_forward_int8 (const int8_t *input, const float *gamma, int8_t *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps, float *scratch_input, float *scratch_output)
 
void rope_backward (const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_backward_bf16 (const uint16_t *d_out, uint16_t *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_d_out, float *scratch_d_x)
 
void rope_backward_inplace (float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_backward_qk (const float *d_q_out, const float *d_k_out, float *d_q, float *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_backward_qk_bf16 (const uint16_t *d_q_out, const uint16_t *d_k_out, uint16_t *d_q, uint16_t *d_k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_dq_out, float *scratch_dq, float *scratch_dk_out, float *scratch_dk)
 
void rope_forward (float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_forward_bf16 (uint16_t *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch)
 
void rope_forward_qk (float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
 
void rope_forward_qk_bf16 (uint16_t *q, uint16_t *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_q, float *scratch_k)
 
void rope_forward_qk_strided (float *q, float *k, const float *cos_cache, const float *sin_cache, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int q_stride_tokens, int k_stride_tokens)
 
void rope_forward_strided (float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)
 
void rope_precompute_cache (float *cos_cache, float *sin_cache, int max_seq_len, int head_dim, float base)
 
void scal_copy_f32 (float *y, const float *x, float alpha, int n)
 Scaled copy: y = alpha * x. More...
 
void sigmoid_backward (const float *input, const float *d_output, float *d_input, size_t n)
 
void sigmoid_backward_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, size_t n, float *scratch_input, float *scratch_d_output, float *scratch_d_input)
 
void sigmoid_forward (const float *input, float *output, size_t n)
 
void sigmoid_forward_bf16 (const uint16_t *input, uint16_t *output, size_t n, float *scratch_input, float *scratch_output)
 
float sigmoid_scalar (float x)
 
void softmax_cross_entropy_loss (const float *logits, const int32_t *targets, int tokens, int vocab_size, float *d_logits, float *loss_out)
 
void softmax_cross_entropy_loss_bf16 (const uint16_t *logits, const int32_t *targets, int tokens, int vocab_size, uint16_t *d_logits, float *loss_out, float *scratch_logits, float *scratch_d_logits)
 
void swiglu_backward (const float *input, const float *d_output, float *d_input, int tokens, int dim)
 
void swiglu_backward_bf16 (const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, int tokens, int dim)
 
void swiglu_backward_exact (const float *input, const float *d_output, float *d_input, int tokens, int dim)
 
void swiglu_forward (const float *input, float *output, int tokens, int dim)
 
void swiglu_forward_bf16 (const uint16_t *input, uint16_t *output, int tokens, int dim)
 
void swiglu_forward_exact (const float *input, float *output, int tokens, int dim)
 
void topk_batched_f32 (const float *scores, int num_tokens, int n_experts, int k, int *indices, float *weights)
 Batched top-K selection for multiple tokens. More...
 
void topk_f32 (const float *scores, int n, int k, int *indices, float *values)
 Find top-K indices and values from a score vector. More...
 
void topk_softmax_f32 (const float *scores, int n, int k, int *indices, float *weights)
 Find top-K indices with softmax-normalized weights. More...
 
void unfused_rmsnorm_qkv_prefill (const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *x_norm, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps)
 Unfused version for benchmarking comparison. More...
 
void vec_dot_q6_k_q8_k (int n, float *s, const void *vx, const void *vy)
 Q6_K x Q8_K dot product (single row) More...
 
void weighted_sum_f32 (float *y, const float **vectors, const float *weights, int k, int n)
 Weighted sum of k vectors: y = sum_i(weights[i] * vectors[i]) More...
 

Function Documentation

◆ add_backward_bf16()

void add_backward_bf16 ( const uint16_t *  d_y,
uint16_t *  d_a,
uint16_t *  d_b,
size_t  n 
)

Definition at line 173 of file add_kernels_bf16.c.

177 {
178  if (!d_y || n == 0) {
179  return;
180  }
181 
182  size_t i = 0;
183 
184  /* Copy to d_a if not in-place */
185  if (d_a && d_a != d_y) {
186 #if defined(__AVX512F__)
187  for (; i + 32 <= n; i += 32) {
188  __m512i v0 = _mm512_loadu_si512((const __m512i*)&d_y[i]);
189  __m512i v1 = _mm512_loadu_si512((const __m512i*)&d_y[i + 32]);
190  _mm512_storeu_si512((__m512i*)&d_a[i], v0);
191  _mm512_storeu_si512((__m512i*)&d_a[i + 32], v1);
192  }
193 #endif
194  for (; i < n; ++i) {
195  d_a[i] = d_y[i];
196  }
197  }
198 
199  /* Copy to d_b if not in-place */
200  i = 0;
201  if (d_b && d_b != d_y) {
202 #if defined(__AVX512F__)
203  for (; i + 32 <= n; i += 32) {
204  __m512i v0 = _mm512_loadu_si512((const __m512i*)&d_y[i]);
205  __m512i v1 = _mm512_loadu_si512((const __m512i*)&d_y[i + 32]);
206  _mm512_storeu_si512((__m512i*)&d_b[i], v0);
207  _mm512_storeu_si512((__m512i*)&d_b[i + 32], v1);
208  }
209 #endif
210  for (; i < n; ++i) {
211  d_b[i] = d_y[i];
212  }
213  }
214 }

◆ add_forward_2d_bf16()

void add_forward_2d_bf16 ( const uint16_t *  a,
const uint16_t *  b,
uint16_t *  y,
int  tokens,
int  dim,
int  aligned_dim 
)

Definition at line 221 of file add_kernels_bf16.c.

227 {
228  if (!a || !b || !y || tokens <= 0 || dim <= 0) {
229  return;
230  }
231 
232  for (int t = 0; t < tokens; ++t) {
233  const uint16_t *a_row = a + (size_t)t * aligned_dim;
234  const uint16_t *b_row = b + (size_t)t * aligned_dim;
235  uint16_t *y_row = y + (size_t)t * aligned_dim;
236 
237  int d = 0;
238 
239 #if defined(__AVX512F__)
240  for (; d + 16 <= dim; d += 16) {
241  __m512 av = bf16_loadu_cvt_fp32(&a_row[d]);
242  __m512 bv = bf16_loadu_cvt_fp32(&b_row[d]);
243  __m512 yv = _mm512_add_ps(av, bv);
244  fp32_cvt_storeu_bf16(&y_row[d], yv);
245  }
246 #endif
247 
248  for (; d < dim; ++d) {
249  float af = bf16_to_float(a_row[d]);
250  float bf = bf16_to_float(b_row[d]);
251  y_row[d] = float_to_bf16(af + bf);
252  }
253  }
254 }
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38

References bf16_to_float(), and float_to_bf16().

◆ add_forward_bf16()

void add_forward_bf16 ( const uint16_t *  a,
const uint16_t *  b,
uint16_t *  y,
size_t  n 
)

Definition at line 38 of file add_kernels_bf16.c.

42 {
43  if (!a || !b || !y || n == 0) {
44  return;
45  }
46 
47  size_t i = 0;
48 
49 #if defined(__AVX512F__)
50  /* AVX-512: Process 16 bf16 elements at a time */
51  for (; i + 16 <= n; i += 16) {
52  __m512 av = bf16_loadu_cvt_fp32(&a[i]);
53  __m512 bv = bf16_loadu_cvt_fp32(&b[i]);
54  __m512 yv = _mm512_add_ps(av, bv);
55  fp32_cvt_storeu_bf16(&y[i], yv);
56  }
57 #endif
58 
59  /* Scalar fallback */
60  for (; i < n; ++i) {
61  float af = bf16_to_float(a[i]);
62  float bf = bf16_to_float(b[i]);
63  y[i] = float_to_bf16(af + bf);
64  }
65 }

References bf16_to_float(), and float_to_bf16().

◆ add_forward_f32()

void add_forward_f32 ( const float *  a,
const float *  b,
float *  y,
size_t  n 
)

Element-wise add: y = a + b

Test:

test_add.py::TestAddForward::test_add_forward_f32

test_add.py::TestAddForward::test_add_inplace_f32

test_multi_layer_parity.py::TestMultiLayerParity::test_residual_add

Element-wise addition of two vectors.

After changes: make test

Definition at line 270 of file add_kernels_bf16.c.

274 {
275  if (!a || !b || !y || n == 0) {
276  return;
277  }
278 
279  size_t i = 0;
280 
281 #if defined(__AVX512F__)
282  for (; i + 16 <= n; i += 16) {
283  __m512 av = _mm512_loadu_ps(&a[i]);
284  __m512 bv = _mm512_loadu_ps(&b[i]);
285  __m512 yv = _mm512_add_ps(av, bv);
286  _mm512_storeu_ps(&y[i], yv);
287  }
288 #endif
289 
290 #if defined(__AVX2__)
291  for (; i + 8 <= n; i += 8) {
292  __m256 av = _mm256_loadu_ps(&a[i]);
293  __m256 bv = _mm256_loadu_ps(&b[i]);
294  __m256 yv = _mm256_add_ps(av, bv);
295  _mm256_storeu_ps(&y[i], yv);
296  }
297 #endif
298 
299  for (; i < n; ++i) {
300  y[i] = a[i] + b[i];
301  }
302 }

◆ add_inplace_bf16()

void add_inplace_bf16 ( uint16_t *  a,
const uint16_t *  b,
size_t  n 
)

Definition at line 105 of file add_kernels_bf16.c.

108 {
109  if (!a || !b || n == 0) {
110  return;
111  }
112 
113  size_t i = 0;
114 
115 #if defined(__AVX512F__)
116  for (; i + 16 <= n; i += 16) {
117  __m512 av = bf16_loadu_cvt_fp32(&a[i]);
118  __m512 bv = bf16_loadu_cvt_fp32(&b[i]);
119  __m512 yv = _mm512_add_ps(av, bv);
120  fp32_cvt_storeu_bf16(&a[i], yv);
121  }
122 #endif
123 
124  for (; i < n; ++i) {
125  float af = bf16_to_float(a[i]);
126  float bf = bf16_to_float(b[i]);
127  a[i] = float_to_bf16(af + bf);
128  }
129 }

References bf16_to_float(), and float_to_bf16().

◆ add_inplace_f32()

void add_inplace_f32 ( float *  a,
const float *  b,
size_t  n 
)

Definition at line 304 of file add_kernels_bf16.c.

307 {
308  if (!a || !b || n == 0) {
309  return;
310  }
311 
312  size_t i = 0;
313 
314 #if defined(__AVX512F__)
315  for (; i + 16 <= n; i += 16) {
316  __m512 av = _mm512_loadu_ps(&a[i]);
317  __m512 bv = _mm512_loadu_ps(&b[i]);
318  __m512 yv = _mm512_add_ps(av, bv);
319  _mm512_storeu_ps(&a[i], yv);
320  }
321 #endif
322 
323 #if defined(__AVX2__)
324  for (; i + 8 <= n; i += 8) {
325  __m256 av = _mm256_loadu_ps(&a[i]);
326  __m256 bv = _mm256_loadu_ps(&b[i]);
327  __m256 yv = _mm256_add_ps(av, bv);
328  _mm256_storeu_ps(&a[i], yv);
329  }
330 #endif
331 
332  for (; i < n; ++i) {
333  a[i] = a[i] + b[i];
334  }
335 }

Referenced by mega_fused_outproj_mlp_prefill().

◆ add_scaled_forward_bf16()

void add_scaled_forward_bf16 ( const uint16_t *  a,
const uint16_t *  b,
uint16_t *  y,
float  alpha,
size_t  n 
)

Definition at line 72 of file add_kernels_bf16.c.

77 {
78  if (!a || !b || !y || n == 0) {
79  return;
80  }
81 
82  size_t i = 0;
83 
84 #if defined(__AVX512F__)
85  __m512 alpha_v = _mm512_set1_ps(alpha);
86  for (; i + 16 <= n; i += 16) {
87  __m512 av = bf16_loadu_cvt_fp32(&a[i]);
88  __m512 bv = bf16_loadu_cvt_fp32(&b[i]);
89  __m512 yv = _mm512_fmadd_ps(bv, alpha_v, av); /* a + alpha * b */
90  fp32_cvt_storeu_bf16(&y[i], yv);
91  }
92 #endif
93 
94  for (; i < n; ++i) {
95  float af = bf16_to_float(a[i]);
96  float bf = bf16_to_float(b[i]);
97  y[i] = float_to_bf16(af + alpha * bf);
98  }
99 }

References bf16_to_float(), and float_to_bf16().

◆ add_scaled_inplace_bf16()

void add_scaled_inplace_bf16 ( uint16_t *  a,
const uint16_t *  b,
float  alpha,
size_t  n 
)

Definition at line 135 of file add_kernels_bf16.c.

139 {
140  if (!a || !b || n == 0) {
141  return;
142  }
143 
144  size_t i = 0;
145 
146 #if defined(__AVX512F__)
147  __m512 alpha_v = _mm512_set1_ps(alpha);
148  for (; i + 16 <= n; i += 16) {
149  __m512 av = bf16_loadu_cvt_fp32(&a[i]);
150  __m512 bv = bf16_loadu_cvt_fp32(&b[i]);
151  __m512 yv = _mm512_fmadd_ps(bv, alpha_v, av);
152  fp32_cvt_storeu_bf16(&a[i], yv);
153  }
154 #endif
155 
156  for (; i < n; ++i) {
157  float af = bf16_to_float(a[i]);
158  float bf = bf16_to_float(b[i]);
159  a[i] = float_to_bf16(af + alpha * bf);
160  }
161 }

References bf16_to_float(), and float_to_bf16().

◆ argmax_f32()

int argmax_f32 ( const float *  scores,
int  n 
)

Find index of maximum value.

Parameters
scoresInput scores [n]
nNumber of scores
Returns
Index of maximum value

Definition at line 226 of file topk_kernels.c.

227 {
228  if (!scores || n <= 0) {
229  return -1;
230  }
231 
232  int max_idx = 0;
233  float max_val = scores[0];
234 
235 #ifdef __AVX512F__
236  /* AVX-512 vectorized argmax for large arrays */
237  if (n >= 16) {
238  __m512 vmax = _mm512_set1_ps(-FLT_MAX);
239  __m512i vidx = _mm512_setzero_si512();
240  __m512i vcur_max_idx = _mm512_setzero_si512();
241 
242  int i = 0;
243  for (; i + 16 <= n; i += 16) {
244  __m512 v = _mm512_loadu_ps(&scores[i]);
245  __m512i cur_idx = _mm512_add_epi32(
246  _mm512_set1_epi32(i),
247  _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
248  );
249 
250  __mmask16 gt_mask = _mm512_cmp_ps_mask(v, vmax, _CMP_GT_OQ);
251  vmax = _mm512_mask_blend_ps(gt_mask, vmax, v);
252  vcur_max_idx = _mm512_mask_blend_epi32(gt_mask, vcur_max_idx, cur_idx);
253  }
254 
255  /* Horizontal reduction */
256  float vals[16];
257  int idxs[16];
258  _mm512_storeu_ps(vals, vmax);
259  _mm512_storeu_si512(idxs, vcur_max_idx);
260 
261  max_val = vals[0];
262  max_idx = idxs[0];
263  for (int j = 1; j < 16; j++) {
264  if (vals[j] > max_val) {
265  max_val = vals[j];
266  max_idx = idxs[j];
267  }
268  }
269 
270  /* Handle remainder */
271  for (; i < n; i++) {
272  if (scores[i] > max_val) {
273  max_val = scores[i];
274  max_idx = i;
275  }
276  }
277 
278  return max_idx;
279  }
280 #endif
281 
282  /* Scalar fallback */
283  for (int i = 1; i < n; i++) {
284  if (scores[i] > max_val) {
285  max_val = scores[i];
286  max_idx = i;
287  }
288  }
289 
290  return max_idx;
291 }

◆ attention_backward_causal_head_major()

void attention_backward_causal_head_major ( const float *  d_output,
const float *  q,
const float *  k,
const float *  v,
const float *  attn_weights,
float *  d_q,
float *  d_k,
float *  d_v,
float *  d_scores,
int  num_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

Causal attention backward (non-GQA version)

Test:

test_attention_backward.py::TestAttentionBackward::test_backward

test_attention_backward.py::TestAttentionBackward::test_backward_vs_separate

test_parity.py::test_attention_backward_parity

Non-GQA version where num_heads == num_kv_heads. Simpler than GQA, no head broadcasting needed.

After changes: make test && make llamacpp-parity-full

Definition at line 1811 of file attention_kernels.c.

1826 {
1827  attention_backward_causal_head_major_gqa(
1828  d_output, q, k, v, attn_weights,
1829  d_q, d_k, d_v, d_scores,
1830  num_heads, num_heads, // num_kv_heads == num_heads
1831  num_tokens, head_dim, aligned_head_dim, aligned_context_window);
1832 }
void attention_backward_causal_head_major_gqa(const float *d_output, const float *q, const float *k, const float *v, const float *attn_weights, float *d_q, float *d_k, float *d_v, float *d_scores, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

References attention_backward_causal_head_major_gqa().

◆ attention_backward_causal_head_major_gqa()

void attention_backward_causal_head_major_gqa ( const float *  d_output,
const float *  q,
const float *  k,
const float *  v,
const float *  attn_weights,
float *  d_q,
float *  d_k,
float *  d_v,
float *  d_scores,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

GQA causal attention backward (score-matrix version)

Test:

test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward

test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_vs_separate

test_parity.py::test_attention_backward_parity

Computes dQ, dK, dV given dOutput and attention weights. Supports grouped-query attention with head broadcasting.

After changes: make test && make llamacpp-parity-full

Definition at line 1672 of file attention_kernels.c.

1688 {
1689  const float scale = 1.0f / sqrtf((float)head_dim);
1690  int T = num_tokens;
1691  int H = num_heads;
1692  int H_kv = num_kv_heads;
1693  int hd = head_dim;
1694  int ad = aligned_head_dim;
1695  int aw = aligned_context_window;
1696 
1697  const size_t d_q_elems = (size_t)H * (size_t)T * (size_t)ad;
1698  const size_t kv_elems = (size_t)H_kv * (size_t)T * (size_t)ad;
1699  /* Zero the aligned outputs so padded lanes never leak garbage to downstream GEMMs. */
1700  for (size_t idx = 0; idx < d_q_elems; ++idx) {
1701  d_q[idx] = 0.0f;
1702  }
1703  for (size_t idx = 0; idx < kv_elems; ++idx) {
1704  d_k[idx] = 0.0f;
1705  d_v[idx] = 0.0f;
1706  }
1707 
1708  // Process each query head
1709  for (int h = 0; h < H; ++h) {
1710  // Which KV head does this query head use?
1711  int kv_h = (int)((long long)h * (long long)H_kv / (long long)H);
1712 
1713  // ----------------------------------------------------------------
1714  // Step 1: d_weights = d_output @ V^T and d_v += weights^T @ d_output
1715  // ----------------------------------------------------------------
1716  // For each query position i, compute d_weights[i, j] for j <= i
1717  // and accumulate d_v[j] contributions
1718 
1719  for (int i = 0; i < T; ++i) {
1720  size_t d_out_base = qkv_index(h, i, 0, T, ad);
1721 
1722  for (int j = 0; j <= i; ++j) {
1723  size_t v_base = qkv_index(kv_h, j, 0, T, ad);
1724  size_t w_idx = score_index(h, i, j, aw);
1725  float w = attn_weights[w_idx];
1726 
1727  // d_weights[h, i, j] = d_output[h, i, :] @ v[kv_h, j, :]^T
1728  float dot = 0.0f;
1729  for (int dd = 0; dd < hd; ++dd) {
1730  dot += d_output[d_out_base + dd] * v[v_base + dd];
1731  }
1732  d_scores[w_idx] = dot;
1733 
1734  // d_v[kv_h, j, :] += weights[h, i, j] * d_output[h, i, :]
1735  for (int dd = 0; dd < hd; ++dd) {
1736  d_v[v_base + dd] += w * d_output[d_out_base + dd];
1737  }
1738  }
1739 
1740  // Zero out upper triangle of d_scores
1741  for (int j = i + 1; j < T; ++j) {
1742  d_scores[score_index(h, i, j, aw)] = 0.0f;
1743  }
1744  /* Scores scratch uses aligned_context_window, zero the padded columns. */
1745  for (int j = T; j < aw; ++j) {
1746  d_scores[score_index(h, i, j, aw)] = 0.0f;
1747  }
1748  }
1749 
1750  // ----------------------------------------------------------------
1751  // Step 2: Backward through softmax (in-place on d_scores for this head)
1752  // ----------------------------------------------------------------
1753  // d_scores = softmax_backward(d_scores, attn_weights)
1754  // Formula: d_score[i,j] = w[i,j] * (d_w[i,j] - sum_k(w[i,k] * d_w[i,k]))
1755 
1756  for (int i = 0; i < T; ++i) {
1757  int base = h * aw * aw + i * aw;
1758 
1759  // Compute dot product: sum_j w[i,j] * d_w[i,j]
1760  float dot_product = 0.0f;
1761  for (int j = 0; j <= i; ++j) {
1762  float wt = attn_weights[base + j];
1763  float dw = d_scores[base + j];
1764  dot_product += wt * dw;
1765  }
1766 
1767  // Apply softmax backward formula
1768  for (int j = 0; j <= i; ++j) {
1769  float wt = attn_weights[base + j];
1770  float dw = d_scores[base + j];
1771  d_scores[base + j] = wt * (dw - dot_product);
1772  }
1773  }
1774 
1775  // ----------------------------------------------------------------
1776  // Step 3: d_q = d_scores @ K * scale
1777  // d_k += d_scores^T @ Q * scale
1778  // ----------------------------------------------------------------
1779 
1780  for (int i = 0; i < T; ++i) {
1781  size_t d_q_base = qkv_index(h, i, 0, T, ad);
1782  size_t q_base = qkv_index(h, i, 0, T, ad);
1783 
1784  // d_q[h, i, :] = sum_j d_scores[h, i, j] * k[kv_h, j, :] * scale
1785  // d_k[kv_h, j, :] += d_scores[h, i, j] * q[h, i, :] * scale
1786  for (int j = 0; j <= i; ++j) {
1787  size_t k_base = qkv_index(kv_h, j, 0, T, ad);
1788  size_t d_k_base = qkv_index(kv_h, j, 0, T, ad);
1789  float ds = d_scores[score_index(h, i, j, aw)] * scale;
1790 
1791  for (int dd = 0; dd < hd; ++dd) {
1792  d_q[d_q_base + dd] += ds * k[k_base + dd];
1793  d_k[d_k_base + dd] += ds * q[q_base + dd];
1794  }
1795  }
1796  }
1797  }
1798 }
static size_t qkv_index(int h, int t, int d, int num_tokens, int aligned_head_dim)
static size_t score_index(int h, int i, int j, int aligned_context_window)

References qkv_index(), and score_index().

Referenced by attention_backward_causal_head_major(), attention_backward_causal_head_major_gqa_bf16(), and ck_layer_backward_rmsnorm_swiglu().

◆ attention_backward_causal_head_major_gqa_bf16()

void attention_backward_causal_head_major_gqa_bf16 ( const uint16_t *  d_output,
float *  d_x,
const uint16_t *  q,
const uint16_t *  k,
const uint16_t *  v,
const float *  attn_weights,
float *  d_q,
float *  d_k,
float *  d_v,
float *  d_scores,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window,
float *  scratch_d_output,
float *  scratch_q,
float *  scratch_k,
float *  scratch_v 
)

BF16 attention backward with caller-provided scratch buffers

Test:
bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_backward

Accepts BF16 inputs, converts to FP32, runs FP32 backward. Caller provides scratch buffers (no per-call malloc).

After changes: make test

Definition at line 1619 of file attention_kernels.c.

1640 {
1641  (void)d_x;
1642  const size_t head_elems = (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
1643  const size_t kv_elems = (size_t)num_kv_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
1644 
1645  if (!scratch_d_output || !scratch_q || !scratch_k || !scratch_v) return;
1646 
1647  convert_bf16_tensor_to_buf(d_output, scratch_d_output, head_elems);
1648  convert_bf16_tensor_to_buf(q, scratch_q, head_elems);
1649  convert_bf16_tensor_to_buf(k, scratch_k, kv_elems);
1650  convert_bf16_tensor_to_buf(v, scratch_v, kv_elems);
1651 
1652  attention_backward_causal_head_major_gqa(scratch_d_output, scratch_q, scratch_k, scratch_v,
1653  attn_weights,
1654  d_q, d_k, d_v, d_scores,
1655  num_heads, num_kv_heads,
1656  num_tokens, head_dim,
1657  aligned_head_dim, aligned_context_window);
1658  /* No free - caller owns scratch buffers */
1659 }
static void convert_bf16_tensor_to_buf(const uint16_t *src, float *dst, size_t count)

References attention_backward_causal_head_major_gqa(), and convert_bf16_tensor_to_buf().

◆ attention_flash_decode()

void attention_flash_decode ( float *  out,
const float *  q,
const float *  k,
const float *  v,
int  T_q,
int  T_k,
int  H,
int  D_h,
float  scale 
)

Main flash attention function with SIMD dispatch.

Parameters
outOutput [T_q, H, D_h]
qQuery [T_q, H, D_h]
kKey [T_k, H, D_h]
vValue [T_k, H, D_h]
T_qNumber of query tokens (1 for decode)
T_kNumber of key/value tokens (context length)
HNumber of heads
D_hHead dimension
scale1/sqrt(D_h)

Definition at line 696 of file attention_flash_true.c.

706 {
707  if (!out || !q || !k || !v) {
708  return;
709  }
710  if (T_q <= 0 || T_k <= 0 || H <= 0 || D_h <= 0) {
711  return;
712  }
713 
714  // Dispatch based on CPU features
715 #if defined(__AVX512F__)
716  attention_flash_decode_avx512(out, q, k, v, T_q, T_k, H, D_h, scale);
717 #elif defined(__AVX__) && !defined(__AVX512F__)
718  attention_flash_decode_avx(out, q, k, v, T_q, T_k, H, D_h, scale);
719 #else
720  attention_flash_decode_scalar(out, q, k, v, T_q, T_k, H, D_h, scale);
721 #endif
722 }
static void attention_flash_decode_scalar(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Scalar flash-style attention (online softmax)

References attention_flash_decode_scalar().

Referenced by attention_forward_decode_head_major_gqa_flash(), ck_attention_flash_decode_wrapper(), mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().

◆ attention_forward_causal_head_major()

void attention_forward_causal_head_major ( const float *  q,
const float *  k,
const float *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

Causal attention forward (score-matrix version)

Test:

test_attention.py::TestAttentionForward::test_causal_forward

test_attention.py::TestAttentionForward::test_gqa_broadcast

test_attention.py::TestAttentionForward::test_exact_vs_fast

test_parity.py::test_attention_parity

Computes softmax(Q @ K^T / sqrt(d)) @ V with causal masking. Uses O(N^2) memory for scores matrix.

After changes: make test && make llamacpp-parity-full

Definition at line 70 of file attention_kernels.c.

80 {
81  const float scale = 1.0f / sqrtf((float)head_dim);
82 
83  // Phase 1: compute scaled dot-product scores Q·K^T / sqrt(d_k),
84  // lower triangle only (j <= i).
85  for (int h = 0; h < num_heads; ++h) {
86  for (int i = 0; i < num_tokens; ++i) {
87  for (int j = 0; j <= i; ++j) {
88  float dot = 0.0f;
89  size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
90  size_t base_k = qkv_index(h, j, 0, num_tokens, aligned_head_dim);
91 
92  for (int d = 0; d < head_dim; ++d) {
93  dot += q[base_q + d] * k[base_k + d];
94  }
95 
96  scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
97  }
98 
99  // Ensure upper triangle is zeroed so there are no stale values
100  // before the softmax kernel runs.
101  for (int j = i + 1; j < num_tokens; ++j) {
102  scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
103  }
104  }
105  }
106 
107  // Phase 2: apply causal row-wise softmax in-place over j <= i.
108  causal_softmax_head_major(scores,
109  num_heads,
110  num_tokens,
111  aligned_context_window);
112 
113  // Phase 3: attention weights · V.
114  for (int h = 0; h < num_heads; ++h) {
115  for (int i = 0; i < num_tokens; ++i) {
116  size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
117 
118  // Zero the full aligned head slice so padded dims stay clean.
119  for (int d = 0; d < aligned_head_dim; ++d) {
120  output[out_base + d] = 0.0f;
121  }
122 
123  // Weighted sum over causal positions.
124  for (int j = 0; j <= i; ++j) {
125  float w = scores[score_index(h, i, j, aligned_context_window)];
126  size_t v_base = qkv_index(h, j, 0, num_tokens, aligned_head_dim);
127 
128  for (int d = 0; d < head_dim; ++d) {
129  output[out_base + d] += w * v[v_base + d];
130  }
131  }
132  }
133  }
134 }
void causal_softmax_head_major(float *scores, int num_heads, int num_tokens, int aligned_context_window)

References causal_softmax_head_major(), qkv_index(), and score_index().

◆ attention_forward_causal_head_major_exact()

void attention_forward_causal_head_major_exact ( const float *  q,
const float *  k,
const float *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

Causal attention forward (exact version using stdlib expf)

Test:

test_attention.py::TestAttentionForward::test_exact_single

test_attention.py::TestAttentionForward::test_exact_vs_fast

Uses standard library expf for numerical accuracy reference. Slower but provides maximum accuracy.

After changes: make test

Definition at line 146 of file attention_kernels.c.

156 {
157  const float scale = 1.0f / sqrtf((float)head_dim);
158 
159  // Phase 1: compute scaled dot-product scores Q·K^T / sqrt(d_k),
160  // lower triangle only (j <= i).
161  for (int h = 0; h < num_heads; ++h) {
162  for (int i = 0; i < num_tokens; ++i) {
163  for (int j = 0; j <= i; ++j) {
164  float dot = 0.0f;
165  size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
166  size_t base_k = qkv_index(h, j, 0, num_tokens, aligned_head_dim);
167 
168  for (int d = 0; d < head_dim; ++d) {
169  dot += q[base_q + d] * k[base_k + d];
170  }
171 
172  scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
173  }
174 
175  // Ensure upper triangle is zeroed so there are no stale values
176  // before the softmax kernel runs.
177  for (int j = i + 1; j < num_tokens; ++j) {
178  scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
179  }
180  }
181  }
182 
183  // Phase 2: apply causal row-wise softmax using exact expf.
184  causal_softmax_head_major_exact(scores,
185  num_heads,
186  num_tokens,
187  aligned_context_window);
188 
189  // Phase 3: attention weights · V.
190  for (int h = 0; h < num_heads; ++h) {
191  for (int i = 0; i < num_tokens; ++i) {
192  size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
193 
194  // Zero the full aligned head slice so padded dims stay clean.
195  for (int d = 0; d < aligned_head_dim; ++d) {
196  output[out_base + d] = 0.0f;
197  }
198 
199  // Weighted sum over causal positions.
200  for (int j = 0; j <= i; ++j) {
201  float w = scores[score_index(h, i, j, aligned_context_window)];
202  size_t v_base = qkv_index(h, j, 0, num_tokens, aligned_head_dim);
203 
204  for (int d = 0; d < head_dim; ++d) {
205  output[out_base + d] += w * v[v_base + d];
206  }
207  }
208  }
209  }
210 }
void causal_softmax_head_major_exact(float *scores, int num_heads, int num_tokens, int aligned_context_window)

References causal_softmax_head_major_exact(), qkv_index(), and score_index().

◆ attention_forward_causal_head_major_gqa()

void attention_forward_causal_head_major_gqa ( const float *  q,
const float *  k,
const float *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

GQA causal attention forward (score-matrix version)

Test:

test_attention.py::TestAttentionForward::test_gqa_forward

test_attention.py::TestAttentionForward::test_gqa_broadcast

test_attention_backward.py::TestAttentionBackwardGQA::test_gqa_backward

test_parity.py::test_attention_gqa_parity

Grouped-query attention: Q has num_heads, K/V have num_kv_heads. Each query head maps to a KV head via ratio.

After changes: make test && make llamacpp-parity-full

Definition at line 224 of file attention_kernels.c.

235 {
236  const float scale = 1.0f / sqrtf((float)head_dim);
237 
238  for (int h = 0; h < num_heads; ++h) {
239  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
240  for (int i = 0; i < num_tokens; ++i) {
241  for (int j = 0; j <= i; ++j) {
242  float dot = 0.0f;
243  size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
244  size_t base_k = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
245 
246  for (int d = 0; d < head_dim; ++d) {
247  dot += q[base_q + d] * k[base_k + d];
248  }
249 
250  scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
251  }
252 
253  for (int j = i + 1; j < num_tokens; ++j) {
254  scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
255  }
256  }
257  }
258 
259  causal_softmax_head_major(scores,
260  num_heads,
261  num_tokens,
262  aligned_context_window);
263 
264  for (int h = 0; h < num_heads; ++h) {
265  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
266  for (int i = 0; i < num_tokens; ++i) {
267  size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
268  for (int d = 0; d < aligned_head_dim; ++d) {
269  output[out_base + d] = 0.0f;
270  }
271 
272  for (int j = 0; j <= i; ++j) {
273  float w = scores[score_index(h, i, j, aligned_context_window)];
274  size_t v_base = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
275 
276  for (int d = 0; d < head_dim; ++d) {
277  output[out_base + d] += w * v[v_base + d];
278  }
279  }
280  }
281  }
282 }

References causal_softmax_head_major(), qkv_index(), and score_index().

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), and ck_layer_forward_rmsnorm_swiglu_ref().

◆ attention_forward_causal_head_major_gqa_bf16()

void attention_forward_causal_head_major_gqa_bf16 ( const uint16_t *  q,
const uint16_t *  k,
const uint16_t *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window,
float *  scratch_q,
float *  scratch_k,
float *  scratch_v 
)

BF16 GQA causal attention forward

Test:

bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_forward

bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa

bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_flash

Accepts BF16 inputs, converts to FP32, uses exact softmax. Caller provides scratch buffers (no per-call malloc).

After changes: make test

Definition at line 366 of file attention_kernels.c.

380 {
381  const size_t q_elems = (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
382  const size_t kv_elems = (size_t)num_kv_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
383 
384  if (!scratch_q || !scratch_k || !scratch_v) return;
385 
386  convert_bf16_tensor_to_buf(q, scratch_q, q_elems);
387  convert_bf16_tensor_to_buf(k, scratch_k, kv_elems);
388  convert_bf16_tensor_to_buf(v, scratch_v, kv_elems);
389 
390  // Use exact version to avoid fast exp approximation error accumulating
391  // with BF16 precision loss.
392  attention_forward_causal_head_major_gqa_exact(scratch_q, scratch_k, scratch_v,
393  scores, output,
394  num_heads, num_kv_heads,
395  num_tokens, head_dim,
396  aligned_head_dim, aligned_context_window);
397  /* No free - caller owns scratch buffers */
398 }
void attention_forward_causal_head_major_gqa_exact(const float *q, const float *k, const float *v, float *scores, float *output, int num_heads, int num_kv_heads, int num_tokens, int head_dim, int aligned_head_dim, int aligned_context_window)

References attention_forward_causal_head_major_gqa_exact(), and convert_bf16_tensor_to_buf().

◆ attention_forward_causal_head_major_gqa_exact()

void attention_forward_causal_head_major_gqa_exact ( const float *  q,
const float *  k,
const float *  v,
float *  scores,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  aligned_context_window 
)

GQA causal attention forward (exact version using stdlib expf)

Test:

test_attention.py::TestAttentionForward::test_gqa_exact

bf16/test_attention_bf16.py::TestAttentionBF16::test_bf16_gqa

Uses standard library expf for numerical accuracy reference. Used by BF16 wrapper to avoid approximation error accumulation.

After changes: make test

Definition at line 294 of file attention_kernels.c.

305 {
306  const float scale = 1.0f / sqrtf((float)head_dim);
307 
308  for (int h = 0; h < num_heads; ++h) {
309  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
310  for (int i = 0; i < num_tokens; ++i) {
311  for (int j = 0; j <= i; ++j) {
312  float dot = 0.0f;
313  size_t base_q = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
314  size_t base_k = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
315 
316  for (int d = 0; d < head_dim; ++d) {
317  dot += q[base_q + d] * k[base_k + d];
318  }
319 
320  scores[score_index(h, i, j, aligned_context_window)] = dot * scale;
321  }
322 
323  for (int j = i + 1; j < num_tokens; ++j) {
324  scores[score_index(h, i, j, aligned_context_window)] = 0.0f;
325  }
326  }
327  }
328 
329  // Use exact softmax with standard library expf
330  causal_softmax_head_major_exact(scores,
331  num_heads,
332  num_tokens,
333  aligned_context_window);
334 
335  for (int h = 0; h < num_heads; ++h) {
336  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
337  for (int i = 0; i < num_tokens; ++i) {
338  size_t out_base = qkv_index(h, i, 0, num_tokens, aligned_head_dim);
339  for (int d = 0; d < aligned_head_dim; ++d) {
340  output[out_base + d] = 0.0f;
341  }
342 
343  for (int j = 0; j <= i; ++j) {
344  float w = scores[score_index(h, i, j, aligned_context_window)];
345  size_t v_base = qkv_index(kv_head, j, 0, num_tokens, aligned_head_dim);
346 
347  for (int d = 0; d < head_dim; ++d) {
348  output[out_base + d] += w * v[v_base + d];
349  }
350  }
351  }
352  }
353 }

References causal_softmax_head_major_exact(), qkv_index(), and score_index().

Referenced by attention_forward_causal_head_major_gqa_bf16().

◆ attention_forward_causal_head_major_gqa_flash()

void attention_forward_causal_head_major_gqa_flash ( const float *  q,
const float *  k,
const float *  v,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim 
)

Flash attention forward for GQA (prefill, no score materialization)

Test:

test_flash_attention.py::TestFlashAttention::test_flash_forward

test_flash_attention.py::TestFlashAttention::test_flash_vs_score_matrix

test_flash_attention.py::TestFlashAttention::test_flash_gqa

test_attention.py::TestAttentionForward::test_flash_forward

Online softmax with streaming KV. O(N) memory instead of O(N^2). For prefill: all tokens attend to previous tokens.

After changes: make test && make llamacpp-parity-full

Definition at line 800 of file attention_kernels.c.

809 {
810  if (!q || !k || !v || !output) {
811  return;
812  }
813  if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
814  return;
815  }
816 
817  const float scale = 1.0f / sqrtf((float)head_dim);
818  const int T = num_tokens;
819 
820  // Select SIMD implementation based on compile-time CPU features
821 #if defined(__AVX512F__)
822  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx512
823 #elif defined(__AVX2__)
824  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx2
825 #elif defined(__AVX__)
826  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx
827 #else
828  #define FLASH_QUERY_IMPL attention_flash_query_causal
829 #endif
830 
831  for (int h = 0; h < num_heads; ++h) {
832  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
833  const float *k_head = k + (size_t)kv_head * (size_t)T * (size_t)aligned_head_dim;
834  const float *v_head = v + (size_t)kv_head * (size_t)T * (size_t)aligned_head_dim;
835 
836  for (int i = 0; i < T; ++i) {
837  const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
838  float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
839  FLASH_QUERY_IMPL(q_vec, k_head, v_head,
840  /*kv_tokens=*/i + 1,
841  head_dim, aligned_head_dim,
842  scale, out_vec);
843  }
844  }
845 
846 #undef FLASH_QUERY_IMPL
847 }
#define FLASH_QUERY_IMPL

References FLASH_QUERY_IMPL, and qkv_index().

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().

◆ attention_forward_causal_head_major_gqa_flash_strided()

void attention_forward_causal_head_major_gqa_flash_strided ( const float *  q,
const float *  k,
const float *  v,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  kv_stride_tokens 
)

Flash attention forward with custom KV stride (for KV cache)

Test:

test_flash_attention.py::TestFlashAttention::test_flash_strided

test_kv_cache_attention.py::TestKVCacheAttention::test_flash_attention

Variant with configurable kv_stride_tokens for KV cache layouts where K/V may not be contiguous in memory.

After changes: make test

Definition at line 859 of file attention_kernels.c.

869 {
870  if (!q || !k || !v || !output) {
871  return;
872  }
873  if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
874  return;
875  }
876  if (kv_stride_tokens < num_tokens) {
877  return;
878  }
879 
880  const float scale = 1.0f / sqrtf((float)head_dim);
881  const int T = num_tokens;
882  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
883 
884  // Select SIMD implementation based on compile-time CPU features
885 #if defined(__AVX512F__)
886  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx512
887 #elif defined(__AVX2__)
888  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx2
889 #elif defined(__AVX__)
890  #define FLASH_QUERY_IMPL attention_flash_query_causal_avx
891 #else
892  #define FLASH_QUERY_IMPL attention_flash_query_causal
893 #endif
894 
895  for (int h = 0; h < num_heads; ++h) {
896  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
897  const float *k_head = k + (size_t)kv_head * kv_head_stride;
898  const float *v_head = v + (size_t)kv_head * kv_head_stride;
899 
900  for (int i = 0; i < T; ++i) {
901  const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
902  float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
903  FLASH_QUERY_IMPL(q_vec, k_head, v_head,
904  /*kv_tokens=*/i + 1,
905  head_dim, aligned_head_dim,
906  scale, out_vec);
907  }
908  }
909 
910 #undef FLASH_QUERY_IMPL
911 }

References FLASH_QUERY_IMPL, and qkv_index().

Referenced by mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), model_layer_0_prefill(), model_layer_10_prefill(), model_layer_11_prefill(), model_layer_12_prefill(), model_layer_13_prefill(), model_layer_14_prefill(), model_layer_15_prefill(), model_layer_16_prefill(), model_layer_17_prefill(), model_layer_18_prefill(), model_layer_19_prefill(), model_layer_1_prefill(), model_layer_20_prefill(), model_layer_21_prefill(), model_layer_22_prefill(), model_layer_23_prefill(), model_layer_2_prefill(), model_layer_3_prefill(), model_layer_4_prefill(), model_layer_5_prefill(), model_layer_6_prefill(), model_layer_7_prefill(), model_layer_8_prefill(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().

◆ attention_forward_causal_head_major_gqa_flash_strided_sliding()

void attention_forward_causal_head_major_gqa_flash_strided_sliding ( const float *  q,
const float *  k,
const float *  v,
float *  output,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  kv_stride_tokens,
int  sliding_window 
)

Flash attention forward with sliding window (prefill)

Test:
test_attention.py::TestAttentionForward::test_sliding_window_prefill

Sliding-window attention for prefill: each token attends to the last W tokens. When sliding_window <= 0, behaves like regular causal attention.

After changes: make test

Definition at line 1316 of file attention_kernels.c.

1328 {
1329  if (!q || !k || !v || !output) {
1330  return;
1331  }
1332  if (num_heads <= 0 || num_kv_heads <= 0 || num_tokens <= 0) {
1333  return;
1334  }
1335  if (kv_stride_tokens < num_tokens) {
1336  return;
1337  }
1338 
1339  const float scale = 1.0f / sqrtf((float)head_dim);
1340  const int T = num_tokens;
1341  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
1342 
1343 #if defined(__AVX512F__)
1344  #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx512
1345 #elif defined(__AVX2__)
1346  #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx2
1347 #elif defined(__AVX__)
1348  #define SLIDING_FLASH_IMPL attention_flash_query_sliding_avx
1349 #else
1350  #define SLIDING_FLASH_IMPL attention_flash_query_sliding
1351 #endif
1352 
1353  for (int h = 0; h < num_heads; ++h) {
1354  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1355  const float *k_head = k + (size_t)kv_head * kv_head_stride;
1356  const float *v_head = v + (size_t)kv_head * kv_head_stride;
1357 
1358  for (int i = 0; i < T; ++i) {
1359  const float *q_vec = q + qkv_index(h, i, 0, T, aligned_head_dim);
1360  float *out_vec = output + qkv_index(h, i, 0, T, aligned_head_dim);
1361  SLIDING_FLASH_IMPL(q_vec, k_head, v_head,
1362  /*query_pos=*/i,
1363  /*kv_tokens=*/T,
1364  head_dim, aligned_head_dim,
1365  scale, out_vec,
1366  sliding_window);
1367  }
1368  }
1369 
1370 #undef SLIDING_FLASH_IMPL
1371 }
#define SLIDING_FLASH_IMPL

References qkv_index(), and SLIDING_FLASH_IMPL.

◆ attention_forward_decode_head_major_gqa_flash()

void attention_forward_decode_head_major_gqa_flash ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim 
)

Flash attention decode (single token attends to KV cache)

Test:

test_flash_attention.py::TestFlashAttention::test_flash_decode

test_kv_cache_attention.py::TestKVCacheAttention::test_flash_decode

test_fused_attention_decode.py::TestFusedAttentionDecode::test_flash_decode

test_attention.py::TestAttentionForward::test_flash_decode

Single query token attends to kv_tokens in KV cache. Uses true flash attention from attention_flash_true.c.

After changes: make test && make llamacpp-parity-full

Definition at line 1467 of file attention_kernels.c.

1477 {
1478  if (!q_token || !k_cache || !v_cache || !out_token) {
1479  return;
1480  }
1481  if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
1482  return;
1483  }
1484  if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
1485  return;
1486  }
1487 
1488  const float scale = 1.0f / sqrtf((float)head_dim);
1489  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
1490 
1491  for (int h = 0; h < num_heads; ++h) {
1492  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1493  const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
1494  const float *k_head = k_cache + (size_t)kv_head * head_stride;
1495  const float *v_head = v_cache + (size_t)kv_head * head_stride;
1496  float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;
1497 
1498  attention_flash_decode(out_head,
1499  q_head,
1500  k_head,
1501  v_head,
1502  1,
1503  kv_tokens,
1504  1,
1505  aligned_head_dim,
1506  scale);
1507  }
1508 }
void attention_flash_decode(float *out, const float *q, const float *k, const float *v, int T_q, int T_k, int H, int D_h, float scale)
Main flash attention function with SIMD dispatch.

References attention_flash_decode().

Referenced by model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().

◆ attention_forward_decode_head_major_gqa_flash_sliding()

void attention_forward_decode_head_major_gqa_flash_sliding ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim,
int  sliding_window 
)

Flash attention decode with sliding window

Test:
test_attention.py::TestAttentionForward::test_sliding_window_decode

Single query token attends to the last W tokens in the KV cache. For decode: effective_kv_tokens = min(kv_tokens, sliding_window)

After changes: make test

Definition at line 1382 of file attention_kernels.c.

1394 {
1395  if (!q_token || !k_cache || !v_cache || !out_token) {
1396  return;
1397  }
1398  if (num_heads <= 0 || num_kv_heads <= 0 || cache_capacity <= 0) {
1399  return;
1400  }
1401  if (kv_tokens <= 0 || kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
1402  return;
1403  }
1404 
1405  const float scale = 1.0f / sqrtf((float)head_dim);
1406  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
1407 
1408  // Compute effective KV tokens based on sliding window
1409  int effective_kv_tokens = kv_tokens;
1410  if (sliding_window > 0 && sliding_window < kv_tokens) {
1411  effective_kv_tokens = sliding_window;
1412  }
1413 
1414  // Guard against empty window (shouldn't happen with kv_tokens >= 1)
1415  if (effective_kv_tokens <= 0) {
1416  return;
1417  }
1418 
1419  // Offset to start reading from the last effective_kv_tokens entries
1420  int kv_start_offset = kv_tokens - effective_kv_tokens;
1421 
1422 #if defined(__AVX512F__)
1423  #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx512
1424 #elif defined(__AVX2__)
1425  #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx2
1426 #elif defined(__AVX__)
1427  #define SLIDING_DECODE_IMPL attention_flash_query_sliding_avx
1428 #else
1429  #define SLIDING_DECODE_IMPL attention_flash_query_sliding
1430 #endif
1431 
1432  for (int h = 0; h < num_heads; ++h) {
1433  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1434  const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
1435  // Offset K/V pointer to start from the first token in the sliding window
1436  const float *k_head = k_cache + (size_t)kv_head * head_stride
1437  + (size_t)kv_start_offset * (size_t)aligned_head_dim;
1438  const float *v_head = v_cache + (size_t)kv_head * head_stride
1439  + (size_t)kv_start_offset * (size_t)aligned_head_dim;
1440  float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;
1441 
1442  // Use query_pos relative to the windowed KV (last token = effective_kv_tokens - 1)
1443  // sliding_window = 0 since we've already windowed via K/V pointer offset
1444  SLIDING_DECODE_IMPL(q_head, k_head, v_head,
1445  /*query_pos=*/effective_kv_tokens - 1,
1446  /*kv_tokens=*/effective_kv_tokens,
1447  head_dim, aligned_head_dim,
1448  scale, out_head,
1449  /*sliding_window=*/0);
1450  }
1451 
1452 #undef SLIDING_DECODE_IMPL
1453 }
#define SLIDING_DECODE_IMPL

References SLIDING_DECODE_IMPL.

◆ attention_forward_decode_head_major_gqa_regular()

void attention_forward_decode_head_major_gqa_regular ( const float *  q_token,
const float *  k_cache,
const float *  v_cache,
float *  out_token,
int  num_heads,
int  num_kv_heads,
int  kv_tokens,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim 
)

WARNING: This is NOT true flash attention!

Despite dispatching to helpers named "flash" (attention_flash_query_causal*), this function implements the regular streaming attention decode path. It's kept for reference and as a fallback.

TRUE flash attention is implemented in attention_flash_true.c

Test:

test_kv_cache_attention.py::TestKVCacheAttention::test_regular_decode

test_attention.py::TestAttentionForward::test_regular_decode

Regular attention decode (non-flash reference path) kept as a fallback; it does not materialize a score matrix.

After changes: make test

Definition at line 1524 of file attention_kernels.c.

1534 {
1535  if (!q_token || !k_cache || !v_cache || !out_token) {
1536  return;
1537  }
1538  if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) {
1539  return;
1540  }
1541  if (kv_tokens > cache_capacity) {
1542  return;
1543  }
1544 
1545  const float scale = 1.0f / sqrtf((float)head_dim);
1546  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
1547 
1548  // Select SIMD implementation based on compile-time CPU features
1549 #if defined(__AVX512F__)
1550  #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx512
1551 #elif defined(__AVX2__)
1552  #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx2
1553 #elif defined(__AVX__)
1554  #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal_avx
1555 #else
1556  #define FLASH_QUERY_IMPL_DECODE attention_flash_query_causal
1557 #endif
1558 
1559 #pragma omp parallel for schedule(static) if(num_heads > 1)
1560  for (int h = 0; h < num_heads; ++h) {
1561  int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);
1562  const float *q_vec = q_token + (size_t)h * (size_t)aligned_head_dim;
1563  const float *k_head = k_cache + (size_t)kv_head * head_stride;
1564  const float *v_head = v_cache + (size_t)kv_head * head_stride;
1565  float *out_vec = out_token + (size_t)h * (size_t)aligned_head_dim;
1566 
1567  FLASH_QUERY_IMPL_DECODE(q_vec, k_head, v_head,
1568  kv_tokens, head_dim, aligned_head_dim,
1569  scale, out_vec);
1570  }
1571 
1572 #undef FLASH_QUERY_IMPL_DECODE
1573 }
#define FLASH_QUERY_IMPL_DECODE

References FLASH_QUERY_IMPL_DECODE.

Referenced by ck_attention_flash_decode_wrapper(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().

◆ axpy_2d_f32()

void axpy_2d_f32 ( float *  Y,
const float *  X,
float  alpha,
int  num_tokens,
int  dim,
int  y_stride,
int  x_stride 
)

Batched AXPY for 2D tensors: Y[t,:] += alpha * X[t,:].

Parameters
YOutput tensor [num_tokens, dim]
XInput tensor [num_tokens, dim]
alphaScalar multiplier
num_tokensNumber of tokens
dimHidden dimension
y_strideStride between Y rows (for alignment)
x_strideStride between X rows

Definition at line 221 of file axpy_kernels.c.

228 {
229  if (!Y || !X || num_tokens <= 0 || dim <= 0) {
230  return;
231  }
232 
233  /* Default strides if not specified */
234  if (y_stride <= 0) y_stride = dim;
235  if (x_stride <= 0) x_stride = dim;
236 
237  for (int t = 0; t < num_tokens; t++) {
238  axpy_f32(Y + t * y_stride, X + t * x_stride, alpha, dim);
239  }
240 }
void axpy_f32(float *y, const float *x, float alpha, int n)
In-place AXPY: y += alpha * x.
Definition: axpy_kernels.c:54

References axpy_f32().

◆ axpy_f32()

void axpy_f32 ( float *  y,
const float *  x,
float  alpha,
int  n 
)

In-place AXPY: y += alpha * x.

Test:

test_axpy.py::TestAXPY::test_axpy_f32

test_axpy.py::TestAXPY::test_axpy_vs_naive

In-place scaled vector addition: y += alpha * x BLAS-like axpy operation.

After changes: make test

Definition at line 54 of file axpy_kernels.c.

58 {
59  if (!y || !x || n <= 0) {
60  return;
61  }
62 
63  int i = 0;
64 
65 #ifdef __AVX512F__
66  __m512 valpha = _mm512_set1_ps(alpha);
67  for (; i + 16 <= n; i += 16) {
68  __m512 vy = _mm512_loadu_ps(&y[i]);
69  __m512 vx = _mm512_loadu_ps(&x[i]);
70  vy = _mm512_fmadd_ps(vx, valpha, vy); /* y = y + alpha * x */
71  _mm512_storeu_ps(&y[i], vy);
72  }
73 #endif
74 
75 #ifdef __AVX2__
76  __m256 valpha256 = _mm256_set1_ps(alpha);
77  for (; i + 8 <= n; i += 8) {
78  __m256 vy = _mm256_loadu_ps(&y[i]);
79  __m256 vx = _mm256_loadu_ps(&x[i]);
80  vy = _mm256_fmadd_ps(vx, valpha256, vy);
81  _mm256_storeu_ps(&y[i], vy);
82  }
83 #endif
84 
85  /* Scalar remainder */
86  for (; i < n; i++) {
87  y[i] += alpha * x[i];
88  }
89 }

Referenced by axpy_2d_f32(), axpy_zero_f32(), moe_accumulate_expert_f32(), and weighted_sum_f32().

◆ axpy_zero_f32()

void axpy_zero_f32 ( float *  y,
const float *  x,
float  alpha,
int  n 
)

Zero output then accumulate: y = 0; y += alpha * x.

Parameters
yOutput vector [n], zeroed then accumulated
xInput vector [n]
alphaScalar multiplier
nVector length

Definition at line 188 of file axpy_kernels.c.

192 {
193  if (!y || n <= 0) {
194  return;
195  }
196 
197  memset(y, 0, n * sizeof(float));
198 
199  if (x) {
200  axpy_f32(y, x, alpha, n);
201  }
202 }

References axpy_f32().

◆ backward_causal_softmax_head_major()

void backward_causal_softmax_head_major ( float *  d_scores,
const float *  weights,
int  num_heads,
int  num_tokens,
int  aligned_context_window 
)

Definition at line 382 of file softmax_kernels.c.

387 {
388  int H = num_heads;
389  int T = num_tokens;
390 
391  for (int h = 0; h < H; ++h) {
392  for (int i = 0; i < T; ++i) {
393  int base = h * aligned_context_window * aligned_context_window
394  + i * aligned_context_window;
395  float *drow = &d_scores[base];
396  const float *wrow = &weights[base];
397  int len = i + 1;
398 
399 #if defined(__AVX512F__)
400  // Compute dot product (vectorized)
401  __m512 dot_vec = _mm512_setzero_ps();
402  int j = 0;
403  for (; j + 16 <= len; j += 16) {
404  __m512 w = _mm512_loadu_ps(&wrow[j]);
405  __m512 dw = _mm512_loadu_ps(&drow[j]);
406  dot_vec = _mm512_fmadd_ps(w, dw, dot_vec);
407  }
408  float dot_product = _mm512_reduce_add_ps(dot_vec);
409  for (; j < len; ++j) {
410  dot_product += wrow[j] * drow[j];
411  }
412 
413  // Compute gradient: d_scores = w * (dw - dot_product)
414  __m512 dot_broadcast = _mm512_set1_ps(dot_product);
415  j = 0;
416  for (; j + 16 <= len; j += 16) {
417  __m512 w = _mm512_loadu_ps(&wrow[j]);
418  __m512 dw = _mm512_loadu_ps(&drow[j]);
419  __m512 diff = _mm512_sub_ps(dw, dot_broadcast);
420  __m512 result = _mm512_mul_ps(w, diff);
421  _mm512_storeu_ps(&drow[j], result);
422  }
423  for (; j < len; ++j) {
424  drow[j] = wrow[j] * (drow[j] - dot_product);
425  }
426 
427  // Zero out future tokens
428  __m512 zero = _mm512_setzero_ps();
429  for (; j + 16 <= T; j += 16) {
430  _mm512_storeu_ps(&drow[j], zero);
431  }
432  for (; j < T; ++j) {
433  drow[j] = 0.0f;
434  }
435 
436 #elif defined(__AVX__)
437  // Compute dot product (vectorized)
438  __m256 dot_vec = _mm256_setzero_ps();
439  int j = 0;
440  for (; j + 8 <= len; j += 8) {
441  __m256 w = _mm256_loadu_ps(&wrow[j]);
442  __m256 dw = _mm256_loadu_ps(&drow[j]);
443  // No FMA in AVX1: use mul + add
444  __m256 prod = _mm256_mul_ps(w, dw);
445  dot_vec = _mm256_add_ps(dot_vec, prod);
446  }
447  float dot_product = hsum256_ps_softmax(dot_vec);
448  for (; j < len; ++j) {
449  dot_product += wrow[j] * drow[j];
450  }
451 
452  // Compute gradient: d_scores = w * (dw - dot_product)
453  __m256 dot_broadcast = _mm256_set1_ps(dot_product);
454  j = 0;
455  for (; j + 8 <= len; j += 8) {
456  __m256 w = _mm256_loadu_ps(&wrow[j]);
457  __m256 dw = _mm256_loadu_ps(&drow[j]);
458  __m256 diff = _mm256_sub_ps(dw, dot_broadcast);
459  __m256 result = _mm256_mul_ps(w, diff);
460  _mm256_storeu_ps(&drow[j], result);
461  }
462  for (; j < len; ++j) {
463  drow[j] = wrow[j] * (drow[j] - dot_product);
464  }
465 
466  // Zero out future tokens
467  __m256 zero = _mm256_setzero_ps();
468  for (; j + 8 <= T; j += 8) {
469  _mm256_storeu_ps(&drow[j], zero);
470  }
471  for (; j < T; ++j) {
472  drow[j] = 0.0f;
473  }
474 
475 #else
476  // Scalar fallback
477  float dot_product = 0.0f;
478  for (int j = 0; j < len; ++j) {
479  dot_product += wrow[j] * drow[j];
480  }
481 
482  for (int j = 0; j < len; ++j) {
483  drow[j] = wrow[j] * (drow[j] - dot_product);
484  }
485 
486  for (int j = len; j < T; ++j) {
487  drow[j] = 0.0f;
488  }
489 #endif
490  }
491  }
492 }

Referenced by backward_causal_softmax_head_major_bf16().

◆ backward_causal_softmax_head_major_bf16()

void backward_causal_softmax_head_major_bf16 ( uint16_t *  d_scores,
const uint16_t *  weights,
int  num_heads,
int  num_tokens,
int  aligned_context_window,
float *  scratch_d_scores,
float *  scratch_weights 
)

Definition at line 53 of file softmax_kernels_bf16.c.

60 {
61  if (!d_scores || !weights || num_heads <= 0 || num_tokens <= 0 || aligned_context_window <= 0) return;
62  if (!scratch_d_scores || !scratch_weights) return;
63 
64  const size_t total = (size_t)num_heads *
65  (size_t)aligned_context_window *
66  (size_t)aligned_context_window;
67 
68  bf16_tensor_to_float(d_scores, scratch_d_scores, total);
69  bf16_tensor_to_float(weights, scratch_weights, total);
70  backward_causal_softmax_head_major(scratch_d_scores, scratch_weights, num_heads, num_tokens, aligned_context_window);
71  float_tensor_to_bf16(scratch_d_scores, d_scores, total);
72 }
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
Definition: bf16_utils.h:271
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)
Definition: bf16_utils.h:250
void backward_causal_softmax_head_major(float *d_scores, const float *weights, int num_heads, int num_tokens, int aligned_context_window)

References backward_causal_softmax_head_major(), bf16_tensor_to_float(), and float_tensor_to_bf16().

◆ causal_softmax_head_major()

void causal_softmax_head_major ( float *  scores,
int  num_heads,
int  num_tokens,
int  aligned_context_window 
)

Causal softmax (in-place, row-wise)

Test:

test_softmax.py::TestSoftmaxForward::test_causal_softmax

test_softmax.py::TestSoftmaxForward::test_causal_vs_softmax

test_attention.py::TestAttentionForward::test_softmax_correctness

Applies causal mask (j > i => 0) and softmax to scores matrix. In-place on [num_heads, T, T] scores matrix.

After changes: make test && make llamacpp-parity-full

Definition at line 144 of file softmax_kernels.c.

148 {
149  for (int h = 0; h < num_heads; ++h) {
150  for (int i = 0; i < num_tokens; ++i) {
151  int base = h * aligned_context_window * aligned_context_window
152  + i * aligned_context_window;
153  float *row = &scores[base];
154  int len = i + 1; // Number of valid elements (0..i inclusive)
155 
156 #if defined(__AVX512F__)
157  // Find max (vectorized)
158  __m512 max_vec = _mm512_set1_ps(-INFINITY);
159  int j = 0;
160  for (; j + 16 <= len; j += 16) {
161  __m512 v = _mm512_loadu_ps(&row[j]);
162  max_vec = _mm512_max_ps(max_vec, v);
163  }
164  float max_val = _mm512_reduce_max_ps(max_vec);
165  for (; j < len; ++j) {
166  if (row[j] > max_val) max_val = row[j];
167  }
168 
169  // Compute exp and sum (vectorized)
170  __m512 max_broadcast = _mm512_set1_ps(max_val);
171  __m512 sum_vec = _mm512_setzero_ps();
172  j = 0;
173  for (; j + 16 <= len; j += 16) {
174  __m512 v = _mm512_loadu_ps(&row[j]);
175  __m512 e = exp512_approx(_mm512_sub_ps(v, max_broadcast));
176  _mm512_storeu_ps(&row[j], e);
177  sum_vec = _mm512_add_ps(sum_vec, e);
178  }
179  float sum = _mm512_reduce_add_ps(sum_vec);
180  for (; j < len; ++j) {
181  float e = expf(row[j] - max_val);
182  row[j] = e;
183  sum += e;
184  }
185 
186  // Normalize (vectorized)
187  float inv_sum = 1.0f / sum;
188  __m512 inv_sum_vec = _mm512_set1_ps(inv_sum);
189  j = 0;
190  for (; j + 16 <= len; j += 16) {
191  __m512 v = _mm512_loadu_ps(&row[j]);
192  _mm512_storeu_ps(&row[j], _mm512_mul_ps(v, inv_sum_vec));
193  }
194  for (; j < len; ++j) {
195  row[j] *= inv_sum;
196  }
197 
198  // Zero out future tokens (vectorized)
199  __m512 zero = _mm512_setzero_ps();
200  for (; j + 16 <= num_tokens; j += 16) {
201  _mm512_storeu_ps(&row[j], zero);
202  }
203  for (; j < num_tokens; ++j) {
204  row[j] = 0.0f;
205  }
206 
207 #elif defined(__AVX2__)
208  // AVX2: Find max (vectorized)
209  __m256 max_vec = _mm256_set1_ps(-INFINITY);
210  int j = 0;
211  for (; j + 8 <= len; j += 8) {
212  __m256 v = _mm256_loadu_ps(&row[j]);
213  max_vec = _mm256_max_ps(max_vec, v);
214  }
215  float max_val = hmax256_ps(max_vec);
216  for (; j < len; ++j) {
217  if (row[j] > max_val) max_val = row[j];
218  }
219 
220  // Compute exp and sum (vectorized with fast exp)
221  __m256 max_broadcast = _mm256_set1_ps(max_val);
222  __m256 sum_vec = _mm256_setzero_ps();
223  j = 0;
224  for (; j + 8 <= len; j += 8) {
225  __m256 v = _mm256_loadu_ps(&row[j]);
226  __m256 e = exp256_approx(_mm256_sub_ps(v, max_broadcast));
227  _mm256_storeu_ps(&row[j], e);
228  sum_vec = _mm256_add_ps(sum_vec, e);
229  }
230  float sum = hsum256_ps_softmax(sum_vec);
231  for (; j < len; ++j) {
232  float e = expf(row[j] - max_val);
233  row[j] = e;
234  sum += e;
235  }
236 
237  // Normalize (vectorized)
238  float inv_sum = 1.0f / sum;
239  __m256 inv_sum_vec = _mm256_set1_ps(inv_sum);
240  j = 0;
241  for (; j + 8 <= len; j += 8) {
242  __m256 v = _mm256_loadu_ps(&row[j]);
243  _mm256_storeu_ps(&row[j], _mm256_mul_ps(v, inv_sum_vec));
244  }
245  for (; j < len; ++j) {
246  row[j] *= inv_sum;
247  }
248 
249  // Zero out future tokens (vectorized)
250  __m256 zero = _mm256_setzero_ps();
251  for (; j + 8 <= num_tokens; j += 8) {
252  _mm256_storeu_ps(&row[j], zero);
253  }
254  for (; j < num_tokens; ++j) {
255  row[j] = 0.0f;
256  }
257 
258 #elif defined(__AVX__)
259  // AVX1: vectorized max/sum/normalize, scalar exp
260  __m256 max_vec = _mm256_set1_ps(-INFINITY);
261  int j = 0;
262  for (; j + 8 <= len; j += 8) {
263  __m256 v = _mm256_loadu_ps(&row[j]);
264  max_vec = _mm256_max_ps(max_vec, v);
265  }
266  float max_val = hmax256_ps(max_vec);
267  for (; j < len; ++j) {
268  if (row[j] > max_val) max_val = row[j];
269  }
270 
271  // Compute exp and sum (scalar exp, no fast approx for AVX1)
272  float sum = 0.0f;
273  for (j = 0; j < len; ++j) {
274  float e = expf(row[j] - max_val);
275  row[j] = e;
276  sum += e;
277  }
278 
279  // Normalize (vectorized)
280  float inv_sum = 1.0f / sum;
281  __m256 inv_sum_vec = _mm256_set1_ps(inv_sum);
282  j = 0;
283  for (; j + 8 <= len; j += 8) {
284  __m256 v = _mm256_loadu_ps(&row[j]);
285  _mm256_storeu_ps(&row[j], _mm256_mul_ps(v, inv_sum_vec));
286  }
287  for (; j < len; ++j) {
288  row[j] *= inv_sum;
289  }
290 
291  // Zero out future tokens (vectorized)
292  __m256 zero = _mm256_setzero_ps();
293  for (; j + 8 <= num_tokens; j += 8) {
294  _mm256_storeu_ps(&row[j], zero);
295  }
296  for (; j < num_tokens; ++j) {
297  row[j] = 0.0f;
298  }
299 
300 #else
301  // Scalar fallback
302  float max_val = row[0];
303  for (int j = 1; j < len; ++j) {
304  if (row[j] > max_val) max_val = row[j];
305  }
306 
307  float sum = 0.0f;
308  for (int j = 0; j < len; ++j) {
309  float e = expf(row[j] - max_val);
310  row[j] = e;
311  sum += e;
312  }
313 
314  float inv_sum = 1.0f / sum;
315  for (int j = 0; j < len; ++j) {
316  row[j] *= inv_sum;
317  }
318 
319  for (int j = len; j < num_tokens; ++j) {
320  row[j] = 0.0f;
321  }
322 #endif
323  }
324  }
325 }

Referenced by attention_forward_causal_head_major(), attention_forward_causal_head_major_gqa(), and causal_softmax_head_major_bf16().

◆ causal_softmax_head_major_bf16()

/**
 * Causal softmax over bf16 attention scores, head-major layout.
 *
 * Round-trips through fp32: widens `scores` into `scratch`, runs the fp32
 * causal softmax, then narrows the result back into `scores`.
 * `scratch` must hold num_heads * aligned_context_window^2 floats.
 */
void causal_softmax_head_major_bf16(uint16_t *scores,
                                    int num_heads,
                                    int num_tokens,
                                    int aligned_context_window,
                                    float *scratch)
{
    /* Reject null buffers and degenerate shapes up front. */
    if (!scores || !scratch) return;
    if (num_heads <= 0 || num_tokens <= 0 || aligned_context_window <= 0) return;

    const size_t acw = (size_t)aligned_context_window;
    const size_t n_elems = (size_t)num_heads * acw * acw;

    bf16_tensor_to_float(scores, scratch, n_elems);
    causal_softmax_head_major(scratch, num_heads, num_tokens, aligned_context_window);
    float_tensor_to_bf16(scratch, scores, n_elems);
}

References bf16_tensor_to_float(), causal_softmax_head_major(), and float_tensor_to_bf16().

◆ causal_softmax_head_major_exact()

void causal_softmax_head_major_exact ( float *  scores,
int  num_heads,
int  num_tokens,
int  aligned_context_window 
)

Causal softmax (exact version using stdlib expf)

Test:

test_softmax.py::TestSoftmaxForward::test_causal_softmax_exact

test_softmax.py::TestSoftmaxForward::test_exact_vs_fast

Exact causal softmax using standard library expf for numerical accuracy reference.

After changes: make test

Definition at line 339 of file softmax_kernels.c.

343 {
344  for (int h = 0; h < num_heads; ++h) {
345  for (int i = 0; i < num_tokens; ++i) {
346  int base = h * aligned_context_window * aligned_context_window
347  + i * aligned_context_window;
348  float *row = &scores[base];
349  int len = i + 1;
350 
351  // Find max
352  float max_val = -INFINITY;
353  for (int j = 0; j < len; ++j) {
354  if (row[j] > max_val) max_val = row[j];
355  }
356 
357  // Compute exp and sum using standard library expf
358  float sum = 0.0f;
359  for (int j = 0; j < len; ++j) {
360  float e = expf(row[j] - max_val);
361  row[j] = e;
362  sum += e;
363  }
364 
365  // Normalize
366  float inv_sum = 1.0f / sum;
367  for (int j = 0; j < len; ++j) {
368  row[j] *= inv_sum;
369  }
370 
371  // Zero out future tokens
372  for (int j = len; j < num_tokens; ++j) {
373  row[j] = 0.0f;
374  }
375  }
376  }
377 }

Referenced by attention_forward_causal_head_major_exact(), and attention_forward_causal_head_major_gqa_exact().

◆ ck_attention_flash_decode_wrapper()

/**
 * Wrapper to call TRUE flash attention from the orchestration layer.
 *
 * @param q_token          Query token [H, D_h]
 * @param k_cache          Cached keys [T_k, H_kv, D_h] (per-head stride = cache_capacity * aligned_head_dim)
 * @param v_cache          Cached values, same layout as k_cache
 * @param out_token        Output [H, D_h]
 * @param num_heads        Number of query heads
 * @param num_kv_heads     Number of KV heads (GQA)
 * @param kv_tokens        Tokens currently in the KV cache
 * @param cache_capacity   Cache capacity in tokens
 * @param head_dim         Head dimension
 * @param aligned_head_dim Padded head dimension used for strides
 *
 * CK_FLASH_ATTN_STRICT=1 routes to the regular (non-flash) decode kernel.
 */
void ck_attention_flash_decode_wrapper(const float *q_token,
                                       const float *k_cache,
                                       const float *v_cache,
                                       float *out_token,
                                       int num_heads,
                                       int num_kv_heads,
                                       int kv_tokens,
                                       int cache_capacity,
                                       int head_dim,
                                       int aligned_head_dim)
{
    /* Guard clauses: null buffers or nonsensical geometry are a no-op. */
    if (!q_token || !k_cache || !v_cache || !out_token) return;
    if (num_heads <= 0 || num_kv_heads <= 0 || kv_tokens <= 0 || cache_capacity <= 0) return;
    if (kv_tokens > cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) return;

    /* Environment probe is cached across calls. */
    static int use_strict = -1;
    if (use_strict < 0) {
        const char *strict_env = getenv("CK_FLASH_ATTN_STRICT");
        use_strict = (strict_env && strict_env[0] && strict_env[0] != '0') ? 1 : 0;
    }

    if (use_strict) {
        /* NOTE(review): the rendered source dropped the callee name on this
           line; restored from the cross-reference list, which names
           attention_forward_decode_head_major_gqa_regular — confirm against
           ckernel_orchestration.c. */
        attention_forward_decode_head_major_gqa_regular(q_token,
                                                        k_cache,
                                                        v_cache,
                                                        out_token,
                                                        num_heads,
                                                        num_kv_heads,
                                                        kv_tokens,
                                                        cache_capacity,
                                                        head_dim,
                                                        aligned_head_dim);
        return;
    }

    const float scale = 1.0f / sqrtf((float)head_dim); /* 1/sqrt(D_h) */
    const size_t kv_head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;

#pragma omp parallel for schedule(static) if(num_heads > 1)
    for (int h = 0; h < num_heads; ++h) {
        /* GQA mapping: query head h reads KV head floor(h * H_kv / H). */
        const int kv_head = (int)((long long)h * (long long)num_kv_heads / (long long)num_heads);

        const float *q_head = q_token + (size_t)h * (size_t)aligned_head_dim;
        const float *k_head = k_cache + (size_t)kv_head * kv_head_stride;
        const float *v_head = v_cache + (size_t)kv_head * kv_head_stride;
        float *out_head = out_token + (size_t)h * (size_t)aligned_head_dim;

        /* Pass aligned_head_dim as D_h so the per-token stride matches the
           cache layout. */
        attention_flash_decode(out_head, q_head, k_head, v_head,
                               1, kv_tokens, 1, aligned_head_dim, scale);
    }
}
void attention_forward_decode_head_major_gqa_regular(const float *q_token, const float *k_cache, const float *v_cache, float *out_token, int num_heads, int num_kv_heads, int kv_tokens, int cache_capacity, int head_dim, int aligned_head_dim)
WARNING: This is NOT true flash attention!

References attention_flash_decode(), and attention_forward_decode_head_major_gqa_regular().

Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), and ck_layer_forward_rmsnorm_swiglu_decode_quant().

◆ ck_flash_attn_choose_tile_k()

/** Public shim over the internal per-head-dim tile-size heuristic. */
int ck_flash_attn_choose_tile_k(int D_h)
{
    return ck_flash_attn_tile_k(D_h);
}
static int ck_flash_attn_tile_k(int D_h)

References ck_flash_attn_tile_k().

◆ ck_flash_attn_fast_exp_kind()

/**
 * Report which fast-exp path this build compiled in:
 *   512 = AVX-512 fast exp, 256 = AVX fast exp, 0 = disabled/scalar.
 */
int ck_flash_attn_fast_exp_kind(void)
{
#if CK_FLASH_ATTN_FAST_EXP && defined(__AVX512F__)
    return 512;
#elif CK_FLASH_ATTN_FAST_EXP && defined(__AVX__)
    return 256;
#else
    return 0;
#endif
}

◆ ck_gemm_nt_head_major_q5_0()

/**
 * Output projection from head-major attention output (Q5_0 weights,
 * auto-dispatch). Replaces flatten_head_major() + ck_gemm_nt_quant() with
 * a single strided-access kernel reading the head-major buffer directly.
 */
void ck_gemm_nt_head_major_q5_0(const float *attn_out,
                                const void *wo,
                                const float *bias,
                                float *output,
                                int tokens,
                                int embed_dim,
                                int num_heads,
                                int head_dim)
{
    /* The AVX path additionally requires F16C for the fp16 block scales. */
#if defined(__AVX__) && defined(__F16C__)
    gemv_nt_q5_0_head_major_output_avx(output, attn_out, wo, bias,
                                       tokens, embed_dim, num_heads, head_dim);
#else
    gemv_nt_q5_0_head_major_output(output, attn_out, wo, bias,
                                   tokens, embed_dim, num_heads, head_dim);
#endif
}
void gemv_nt_q5_0_head_major_output(float *output, const float *attn_out, const void *wo, const float *bias, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection reading head-major attention output (Q5_0 weights)

References gemv_nt_q5_0_head_major_output().

Referenced by mega_fused_attention_prefill().

◆ ck_gemm_nt_head_major_q8_0()

void ck_gemm_nt_head_major_q8_0 ( const float *  attn_out,
const void *  wo,
const float *  bias,
float *  output,
int  tokens,
int  embed_dim,
int  num_heads,
int  head_dim 
)

Output projection from head-major attention (Q8_0 weights)

Definition at line 353 of file gemm_head_major_output.c.

361 {
362  if (!output || !attn_out || !wo) return;
363  if (tokens <= 0 || embed_dim <= 0 || num_heads <= 0 || head_dim <= 0) return;
364 
365  const int blocks_per_head = head_dim / QK8_0;
366  const int blocks_per_row = embed_dim / QK8_0;
367  const block_q8_0 *weights = (const block_q8_0 *)wo;
368 
369  const size_t token_stride = head_dim;
370  const size_t head_stride = (size_t)tokens * token_stride;
371 
372  /* Initialize output */
373  if (bias) {
374  for (int t = 0; t < tokens; t++) {
375  float *out_row = output + (size_t)t * embed_dim;
376  for (int n = 0; n < embed_dim; n++) {
377  out_row[n] = bias[n];
378  }
379  }
380  } else {
381  memset(output, 0, (size_t)tokens * embed_dim * sizeof(float));
382  }
383 
384  /* Accumulate from each head */
385  for (int h = 0; h < num_heads; h++) {
386  const float *head_data = attn_out + (size_t)h * head_stride;
387  const int head_offset = h * blocks_per_head;
388 
389  for (int n_block = 0; n_block < blocks_per_head; n_block++) {
390  for (int n = 0; n < embed_dim; n++) {
391  const block_q8_0 *w_row = weights + (size_t)n * blocks_per_row + head_offset + n_block;
392  const float d = CK_FP16_TO_FP32(w_row->d);
393 
394  for (int t = 0; t < tokens; t++) {
395  const float *token_vec = head_data + (size_t)t * token_stride + (size_t)n_block * QK8_0;
396  float sum = 0.0f;
397 
398  for (int j = 0; j < QK8_0; j++) {
399  sum += d * (float)w_row->qs[j] * token_vec[j];
400  }
401 
402  output[(size_t)t * embed_dim + n] += sum;
403  }
404  }
405  }
406  }
407 }
#define CK_FP16_TO_FP32(x)
#define QK8_0
int8_t qs[32]

References CK_FP16_TO_FP32, block_q8_0::d, QK8_0, and block_q8_0::qs.

Referenced by mega_fused_attention_prefill().

◆ ck_get_num_threads()

int ck_get_num_threads ( void  )

Definition at line 178 of file ckernel_strict.c.

179 {
180  // Auto-initialize if not set
181  if (!g_threads_initialized) {
182  ck_set_num_threads(0); // Auto-detect
183  }
184  return g_num_threads;
185 }
void ck_set_num_threads(int num_threads)
static int g_num_threads
static int g_threads_initialized

References ck_set_num_threads(), g_num_threads, and g_threads_initialized.

Referenced by gemm_blocked_serial().

◆ ck_get_physical_cores()

/**
 * Best-effort physical core count (Linux).
 *
 * Counts unique (physical id, core id) pairs in /proc/cpuinfo; if that
 * yields <= 1 while sysconf reports more logical CPUs (common inside
 * containers or on kernels that omit those fields), falls back to the
 * logical CPU count rather than forcing single-threaded execution.
 *
 * @return >= 1; never fails (worst case returns 1).
 */
int ck_get_physical_cores(void)
{
    int physical_cores = 0;
    /* Logical CPU count as the baseline / fallback. */
    int logical_cores = (int)sysconf(_SC_NPROCESSORS_ONLN);
    if (logical_cores <= 0) {
        logical_cores = 1;
    }

    // Read from /proc/cpuinfo (Linux) and count unique (physical id, core id) pairs.
    FILE *f = fopen("/proc/cpuinfo", "r");
    if (f) {
        char line[256];
        int physical_id = -1;
        int core_id = -1;

        /* Fixed-capacity set of pairs on the stack (8192 * 8 B = 64 KiB);
           pairs beyond the cap are silently dropped by CK_ADD_PAIR. */
        struct {
            int physical_id;
            int core_id;
        } seen[8192];
        int seen_count = 0;

        const int seen_cap = (int)(sizeof(seen) / sizeof(seen[0]));

        // Helper: add (pid,cid) to set if not present.
        #define CK_ADD_PAIR(pid, cid) \
        do { \
            if ((pid) >= 0 && (cid) >= 0) { \
                int exists = 0; \
                for (int ii = 0; ii < seen_count; ++ii) { \
                    if (seen[ii].physical_id == (pid) && \
                        seen[ii].core_id == (cid)) { \
                        exists = 1; \
                        break; \
                    } \
                } \
                if (!exists && seen_count < seen_cap) { \
                    seen[seen_count].physical_id = (pid); \
                    seen[seen_count].core_id = (cid); \
                    ++seen_count; \
                } \
            } \
        } while (0)

        while (fgets(line, sizeof(line), f)) {
            int val;

            // Blank line separates processor blocks.
            if (line[0] == '\n' || line[0] == '\0') {
                /* Flush the pair collected for the block that just ended. */
                CK_ADD_PAIR(physical_id, core_id);
                physical_id = -1;
                core_id = -1;
                continue;
            }

            /* Whitespace in a scanf format matches any whitespace run,
               so these also match the "physical id\t: N" tab layout. */
            if (sscanf(line, "physical id : %d", &val) == 1) {
                physical_id = val;
                continue;
            }
            if (sscanf(line, "core id : %d", &val) == 1) {
                core_id = val;
                continue;
            }
        }
        fclose(f);

        // Handle file without trailing blank line.
        CK_ADD_PAIR(physical_id, core_id);

        #undef CK_ADD_PAIR

        physical_cores = seen_count;
    }

    // If we couldn't reliably detect physical cores (common in containers),
    // fall back to logical CPUs instead of incorrectly forcing single-thread execution.
    if (physical_cores <= 1 && logical_cores > 1) {
        return logical_cores;
    }

    if (physical_cores > 1) {
        return physical_cores;
    }

    return logical_cores;
}
#define CK_ADD_PAIR(pid, cid)

References CK_ADD_PAIR.

◆ ck_set_num_threads()

void ck_set_num_threads ( int  num_threads)

Definition at line 148 of file ckernel_strict.c.

149 {
150  // 0 = auto-detect
151  if (num_threads <= 0) {
152  // Prefer explicit env controls when present:
153  // - CK_NUM_THREADS: engine-level override
154  // - OMP_NUM_THREADS: standard OpenMP control (set by `ck run --threads`)
155  int env_threads = ck_parse_env_int("CK_NUM_THREADS");
156  if (env_threads <= 0) {
157  env_threads = ck_parse_env_int("OMP_NUM_THREADS");
158  }
159  num_threads = env_threads > 0 ? env_threads : ck_get_physical_cores();
160  }
161 
162  g_num_threads = num_threads;
164 
165 #ifdef _OPENMP
166  omp_set_dynamic(0); // Disable dynamic adjustment
167  omp_set_num_threads(num_threads);
168 #endif
169 
170 #if defined(USE_MKL)
171  mkl_set_num_threads(num_threads);
172 #endif
173 
174  fprintf(stderr, "[CK] Set %d threads (auto=%d)\n",
175  num_threads, ck_get_physical_cores());
176 }
static int ck_parse_env_int(const char *name)
int ck_get_physical_cores(void)

References ck_get_physical_cores(), ck_parse_env_int(), g_num_threads, and g_threads_initialized.

Referenced by ck_get_num_threads().

◆ ck_set_strict_parity()

void ck_set_strict_parity ( int  enabled)

Definition at line 22 of file ckernel_strict.c.

23 {
24  ck_strict_parity = enabled ? 1 : 0;
25 #ifdef _OPENMP
26  if (ck_strict_parity) {
27  omp_set_dynamic(0);
28  omp_set_num_threads(1);
29  }
30 #endif
31 }
static int ck_strict_parity

References ck_strict_parity.

◆ ck_strict_parity_enabled()

◆ ckernel_backend_native()

CKMathBackend ckernel_backend_native ( void  )

Obtain the built-in native backend (single-node CPU, C + intrinsics).

Definition at line 39 of file backend_native.c.

40 {
41  CKMathBackend b;
43  return b;
44 }
static void ckernel_sgemm_native(int M, int N, int K, const float *A, int lda, const float *B, int ldb, const float *bias, float *C, int ldc)
void(* sgemm)(int M, int N, int K, const float *A, int lda, const float *B, int ldb, const float *bias, float *C, int ldc)

References ckernel_sgemm_native(), and CKMathBackend::sgemm.

◆ dequant_q4_0_row()

void dequant_q4_0_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q4_0 row (multiple blocks)

Parameters
srcQ4_0 data
dstFP32 output
n_elementsNumber of elements to dequantize

Definition at line 61 of file dequant_kernels.c.

62 {
63  const block_q4_0 *blocks = (const block_q4_0 *)src;
64  const size_t n_blocks = n_elements / QK4_0;
65 
66  for (size_t b = 0; b < n_blocks; b++) {
67  dequant_q4_0_block(&blocks[b], &dst[b * QK4_0]);
68  }
69 }
#define QK4_0
Definition: ckernel_quant.h:35
void dequant_q4_0_block(const block_q4_0 *block, float *output)
Dequantize a single Q4_0 block to FP32.

References dequant_q4_0_block(), and QK4_0.

◆ dequant_q4_1_row()

void dequant_q4_1_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q4_1 row (multiple blocks)

Definition at line 139 of file dequant_kernels.c.

140 {
141  const block_q4_1 *blocks = (const block_q4_1 *)src;
142  const size_t n_blocks = n_elements / QK4_1;
143 
144  for (size_t b = 0; b < n_blocks; b++) {
145  dequant_q4_1_block(&blocks[b], &dst[b * QK4_1]);
146  }
147 }
#define QK4_1
Definition: ckernel_quant.h:50
void dequant_q4_1_block(const block_q4_1 *block, float *output)
Dequantize a single Q4_1 block to FP32.

References dequant_q4_1_block(), and QK4_1.

Referenced by dequant_row().

◆ dequant_q4_k_row()

void dequant_q4_k_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q4_K row (multiple blocks)

Definition at line 370 of file dequant_kernels.c.

371 {
372  const block_q4_K *blocks = (const block_q4_K *)src;
373  const size_t n_blocks = n_elements / QK_K;
374 
375  for (size_t b = 0; b < n_blocks; b++) {
376  dequant_q4_k_block(&blocks[b], &dst[b * QK_K]);
377  }
378 }
#define QK_K
void dequant_q4_k_block(const block_q4_K *block, float *output)
Dequantize a single Q4_K block to FP32.

References dequant_q4_k_block(), and QK_K.

Referenced by embedding_forward_q4_k().

◆ dequant_q5_0_row()

void dequant_q5_0_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q5_0 row (multiple blocks)

Definition at line 196 of file dequant_kernels.c.

197 {
198  const block_q5_0 *blocks = (const block_q5_0 *)src;
199  const size_t n_blocks = n_elements / QK5_0;
200 
201  for (size_t b = 0; b < n_blocks; b++) {
202  dequant_q5_0_block(&blocks[b], &dst[b * QK5_0]);
203  }
204 }
#define QK5_0
Definition: ckernel_quant.h:67
void dequant_q5_0_block(const block_q5_0 *block, float *output)
Dequantize a single Q5_0 block to FP32.

References dequant_q5_0_block(), and QK5_0.

◆ dequant_q5_1_row()

void dequant_q5_1_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q5_1 row (multiple blocks)

Definition at line 255 of file dequant_kernels.c.

256 {
257  const block_q5_1 *blocks = (const block_q5_1 *)src;
258  const size_t n_blocks = n_elements / QK5_1;
259 
260  for (size_t b = 0; b < n_blocks; b++) {
261  dequant_q5_1_block(&blocks[b], &dst[b * QK5_1]);
262  }
263 }
#define QK5_1
Definition: ckernel_quant.h:84
void dequant_q5_1_block(const block_q5_1 *block, float *output)
Dequantize a single Q5_1 block to FP32.

References dequant_q5_1_block(), and QK5_1.

◆ dequant_q6_k_row()

void dequant_q6_k_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q6_K row (multiple blocks)

Definition at line 420 of file dequant_kernels.c.

421 {
422  const block_q6_K *blocks = (const block_q6_K *)src;
423  const size_t n_blocks = n_elements / QK_K;
424 
425  for (size_t b = 0; b < n_blocks; b++) {
426  dequant_q6_k_block(&blocks[b], &dst[b * QK_K]);
427  }
428 }
void dequant_q6_k_block(const block_q6_K *block, float *output)
Dequantize a single Q6_K block to FP32.

References dequant_q6_k_block(), and QK_K.

Referenced by embedding_forward_q6_k().

◆ dequant_q8_0_row()

void dequant_q8_0_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q8_0 row (multiple blocks)

Definition at line 286 of file dequant_kernels.c.

287 {
288  const block_q8_0 *blocks = (const block_q8_0 *)src;
289  const size_t n_blocks = n_elements / QK8_0;
290 
291  for (size_t b = 0; b < n_blocks; b++) {
292  dequant_q8_0_block(&blocks[b], &dst[b * QK8_0]);
293  }
294 }
void dequant_q8_0_block(const block_q8_0 *block, float *output)
Dequantize a single Q8_0 block to FP32.

References dequant_q8_0_block(), and QK8_0.

Referenced by dequant_row(), and embedding_forward_q8_0().

◆ embedding_backward()

/**
 * Accumulate embedding gradients for a token sequence.
 *
 * For each position t (count clamped to [0, context_window]), adds
 * d_output row t into the token-embedding gradient row of token_ids[t]
 * (out-of-range ids are folded to row 0) and, when add_pos is set and
 * d_pos_embeddings is non-NULL, into positional gradient row t. Rows are
 * aligned_embed_dim floats apart; only the first embed_dim entries are
 * touched.
 */
void embedding_backward(const int32_t *token_ids,
                        int token_count,
                        const float *d_output,
                        float *d_token_embeddings,
                        float *d_pos_embeddings,
                        int vocab_size,
                        int embed_dim,
                        int aligned_embed_dim,
                        int context_window,
                        int add_pos)
{
    if (!token_ids || !d_output || !d_token_embeddings) {
        return;
    }

    /* Clamp the token count into [0, context_window]. */
    int n = token_count;
    if (n < 0) n = 0;
    if (n > context_window) n = context_window;

    const size_t stride = (size_t)aligned_embed_dim;

    for (int t = 0; t < n; ++t) {
        int id = token_ids[t];
        if (id < 0 || id >= vocab_size) {
            id = 0; /* fold invalid ids to row 0 */
        }

        const float *grad_row = d_output + (size_t)t * stride;
        float *tok_grad = d_token_embeddings + (size_t)id * stride;
        float *pos_grad =
            (add_pos && d_pos_embeddings) ? d_pos_embeddings + (size_t)t * stride : NULL;

        for (int d = 0; d < embed_dim; ++d) {
            tok_grad[d] += grad_row[d];
            if (pos_grad) {
                pos_grad[d] += grad_row[d];
            }
        }
    }
}
int vocab_size
Definition: true_bpe.h:185

References vocab_size.

◆ embedding_backward_bf16()

/**
 * bf16 variant of embedding_backward: each gradient element is widened to
 * fp32, added to the widened accumulator value, and stored back as bf16
 * (read-modify-write per element).
 */
void embedding_backward_bf16(const int32_t *token_ids,
                             int token_count,
                             const uint16_t *d_output,
                             uint16_t *d_token_embeddings,
                             uint16_t *d_pos_embeddings,
                             int vocab_size,
                             int embed_dim,
                             int aligned_embed_dim,
                             int context_window,
                             int add_pos)
{
    if (!token_ids || !d_output || !d_token_embeddings) {
        return;
    }

    /* Clamp the token count into [0, context_window]. */
    int n = token_count;
    if (n < 0) n = 0;
    if (n > context_window) n = context_window;

    const size_t stride = (size_t)aligned_embed_dim;

    for (int t = 0; t < n; ++t) {
        int id = token_ids[t];
        if (id < 0 || id >= vocab_size) {
            id = 0; /* fold invalid ids to row 0 */
        }

        const uint16_t *grad_row = d_output + (size_t)t * stride;
        uint16_t *tok_grad = d_token_embeddings + (size_t)id * stride;
        uint16_t *pos_grad =
            d_pos_embeddings ? d_pos_embeddings + (size_t)t * stride : NULL;

        for (int d = 0; d < embed_dim; ++d) {
            const float g = bf16_to_float(grad_row[d]);

            tok_grad[d] = float_to_bf16(bf16_to_float(tok_grad[d]) + g);

            if (add_pos && pos_grad) {
                pos_grad[d] = float_to_bf16(bf16_to_float(pos_grad[d]) + g);
            }
        }
    }
}

References bf16_to_float(), float_to_bf16(), and vocab_size.

◆ embedding_forward()

/**
 * Gather token embeddings (optionally plus positional embeddings) into a
 * [context_window, aligned_embed_dim] fp32 buffer.
 *
 * Invalid token ids are folded to row 0; alignment padding and rows past
 * the last real token are zeroed.
 */
void embedding_forward(const int32_t *token_ids,
                       int token_count,
                       int vocab_size,
                       const float *token_embeddings,
                       const float *pos_embeddings,
                       float *output,
                       int embed_dim,
                       int aligned_embed_dim,
                       int context_window,
                       int add_pos)
{
    if (!token_ids || !token_embeddings || !output) {
        return;
    }

    /* Clamp the token count into [0, context_window]. */
    int n = token_count;
    if (n < 0) n = 0;
    if (n > context_window) n = context_window;

    const size_t stride = (size_t)aligned_embed_dim;

    for (int t = 0; t < n; ++t) {
        int id = token_ids[t];
        if (id < 0 || id >= vocab_size) {
            id = 0; /* fold invalid ids to row 0 */
        }

        const float *tok_row = token_embeddings + (size_t)id * stride;
        float *dst = output + (size_t)t * stride;

        if (add_pos && pos_embeddings) {
            const float *pos_row = pos_embeddings + (size_t)t * stride;
            for (int d = 0; d < embed_dim; ++d) {
                dst[d] = tok_row[d] + pos_row[d];
            }
        } else {
            for (int d = 0; d < embed_dim; ++d) {
                dst[d] = tok_row[d];
            }
        }

        /* Zero alignment padding. */
        for (int d = embed_dim; d < aligned_embed_dim; ++d) {
            dst[d] = 0.0f;
        }
    }

    /* Zero rows past the last real token. */
    for (int t = n; t < context_window; ++t) {
        memset(output + (size_t)t * stride, 0, stride * sizeof(float));
    }
}

References vocab_size.

◆ embedding_forward_bf16()

/**
 * bf16 variant of embedding_forward. The positional add is performed in
 * fp32 and rounded back to bf16; a plain lookup copies bits unchanged.
 * Padding and unused rows are zeroed.
 */
void embedding_forward_bf16(const int32_t *token_ids,
                            int token_count,
                            int vocab_size,
                            const uint16_t *token_embeddings,
                            const uint16_t *pos_embeddings,
                            uint16_t *output,
                            int embed_dim,
                            int aligned_embed_dim,
                            int context_window,
                            int add_pos)
{
    if (!token_ids || !token_embeddings || !output) {
        return;
    }

    /* Clamp the token count into [0, context_window]. */
    int n = token_count;
    if (n < 0) n = 0;
    if (n > context_window) n = context_window;

    const size_t stride = (size_t)aligned_embed_dim;

    for (int t = 0; t < n; ++t) {
        int id = token_ids[t];
        if (id < 0 || id >= vocab_size) {
            id = 0; /* fold invalid ids to row 0 */
        }

        const uint16_t *tok_row = token_embeddings + (size_t)id * stride;
        uint16_t *dst = output + (size_t)t * stride;

        if (add_pos && pos_embeddings) {
            const uint16_t *pos_row = pos_embeddings + (size_t)t * stride;
            for (int d = 0; d < embed_dim; ++d) {
                dst[d] = float_to_bf16(bf16_to_float(tok_row[d]) + bf16_to_float(pos_row[d]));
            }
        } else {
            for (int d = 0; d < embed_dim; ++d) {
                dst[d] = tok_row[d];
            }
        }

        /* Zero alignment padding. */
        for (int d = embed_dim; d < aligned_embed_dim; ++d) {
            dst[d] = 0;
        }
    }

    /* Zero rows past the last real token. */
    for (int t = n; t < context_window; ++t) {
        memset(output + (size_t)t * stride, 0, stride * sizeof(uint16_t));
    }
}

References bf16_to_float(), float_to_bf16(), and vocab_size.

◆ embedding_forward_q4_k()

void embedding_forward_q4_k ( const int32_t *  token_ids,
int  token_count,
int  vocab_size,
const void *  token_embeddings,
const float *  pos_embeddings,
float *  output,
int  embed_dim,
int  aligned_embed_dim,
int  context_window,
int  add_pos 
)

Definition at line 76 of file embedding_kernels.c.

86 {
87  if (!token_ids || !token_embeddings || !output) {
88  return;
89  }
90 
91  int tokens = token_count;
92  if (tokens < 0) {
93  tokens = 0;
94  }
95  if (tokens > context_window) {
96  tokens = context_window;
97  }
98 
99  const size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q4_K, (size_t)aligned_embed_dim);
100  const uint8_t *base = (const uint8_t *)token_embeddings;
101 
102  for (int t = 0; t < tokens; ++t) {
103  int id = token_ids[t];
104  if (id < 0 || id >= vocab_size) {
105  id = 0;
106  }
107 
108  const void *tok = base + (size_t)id * row_bytes;
109  const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
110  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
111 
112  dequant_q4_k_row(tok, out, (size_t)aligned_embed_dim);
113 
114  if (add_pos && pos) {
115  for (int d = 0; d < embed_dim; ++d) {
116  out[d] += pos[d];
117  }
118  }
119 
120  for (int d = embed_dim; d < aligned_embed_dim; ++d) {
121  out[d] = 0.0f;
122  }
123  }
124 
125  for (int t = tokens; t < context_window; ++t) {
126  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
127  memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
128  }
129 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void dequant_q4_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q4_K row (multiple blocks)

References CK_DT_Q4_K, ck_dtype_row_bytes(), dequant_q4_k_row(), and vocab_size.

Referenced by model_decode_token(), model_forward_prefill_impl(), qwen2_0_5b_decode_decode_token(), and qwen2_0_5b_decode_forward_prefill_impl().

◆ embedding_forward_q6_k()

void embedding_forward_q6_k ( const int32_t *  token_ids,
int  token_count,
int  vocab_size,
const void *  token_embeddings,
const float *  pos_embeddings,
float *  output,
int  embed_dim,
int  aligned_embed_dim,
int  context_window,
int  add_pos 
)

Definition at line 186 of file embedding_kernels.c.

196 {
197  if (!token_ids || !token_embeddings || !output) {
198  return;
199  }
200 
201  int tokens = token_count;
202  if (tokens < 0) {
203  tokens = 0;
204  }
205  if (tokens > context_window) {
206  tokens = context_window;
207  }
208 
209  const size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q6_K, (size_t)aligned_embed_dim);
210  const uint8_t *base = (const uint8_t *)token_embeddings;
211 
212  for (int t = 0; t < tokens; ++t) {
213  int id = token_ids[t];
214  if (id < 0 || id >= vocab_size) {
215  id = 0;
216  }
217 
218  const void *tok = base + (size_t)id * row_bytes;
219  const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
220  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
221 
222  dequant_q6_k_row(tok, out, (size_t)aligned_embed_dim);
223 
224  if (add_pos && pos) {
225  for (int d = 0; d < embed_dim; ++d) {
226  out[d] += pos[d];
227  }
228  }
229 
230  for (int d = embed_dim; d < aligned_embed_dim; ++d) {
231  out[d] = 0.0f;
232  }
233  }
234 
235  for (int t = tokens; t < context_window; ++t) {
236  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
237  memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
238  }
239 }
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
void dequant_q6_k_row(const void *src, float *dst, size_t n_elements)
Dequantize Q6_K row (multiple blocks)

References CK_DT_Q6_K, ck_dtype_row_bytes(), dequant_q6_k_row(), and vocab_size.

◆ embedding_forward_q8_0()

void embedding_forward_q8_0 ( const int32_t *  token_ids,
int  token_count,
int  vocab_size,
const void *  token_embeddings,
const float *  pos_embeddings,
float *  output,
int  embed_dim,
int  aligned_embed_dim,
int  context_window,
int  add_pos 
)

Definition at line 131 of file embedding_kernels.c.

141 {
142  if (!token_ids || !token_embeddings || !output) {
143  return;
144  }
145 
146  int tokens = token_count;
147  if (tokens < 0) {
148  tokens = 0;
149  }
150  if (tokens > context_window) {
151  tokens = context_window;
152  }
153 
154  const size_t row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
155  const uint8_t *base = (const uint8_t *)token_embeddings;
156 
157  for (int t = 0; t < tokens; ++t) {
158  int id = token_ids[t];
159  if (id < 0 || id >= vocab_size) {
160  id = 0;
161  }
162 
163  const void *tok = base + (size_t)id * row_bytes;
164  const float *pos = pos_embeddings ? (pos_embeddings + (size_t)t * (size_t)aligned_embed_dim) : NULL;
165  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
166 
167  dequant_q8_0_row(tok, out, (size_t)aligned_embed_dim);
168 
169  if (add_pos && pos) {
170  for (int d = 0; d < embed_dim; ++d) {
171  out[d] += pos[d];
172  }
173  }
174 
175  for (int d = embed_dim; d < aligned_embed_dim; ++d) {
176  out[d] = 0.0f;
177  }
178  }
179 
180  for (int t = tokens; t < context_window; ++t) {
181  float *out = output + (size_t)t * (size_t)aligned_embed_dim;
182  memset(out, 0, (size_t)aligned_embed_dim * sizeof(float));
183  }
184 }
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
void dequant_q8_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q8_0 row (multiple blocks)

References CK_DT_Q8_0, ck_dtype_row_bytes(), dequant_q8_0_row(), and vocab_size.

Referenced by qwen2_0_5b_decode_decode_token(), and qwen2_0_5b_decode_forward_prefill_impl().

◆ fc1_backward_kernel()

/**
 * Backward pass for FC1.
 *
 *   d_input[T, in]  = d_output[T, out] @ W[out, in]
 *   d_W[out, in]    = d_output[T, out]^T @ fc1_input[T, in]   (overwrites)
 *   d_b[out]       += column sums of d_output
 *
 * Threading is handled inside the GEMM kernels; num_threads is unused.
 */
void fc1_backward_kernel(const float *d_output,
                         const float *fc1_input,
                         const float *W_fc1,
                         float *d_input,
                         float *d_W_fc1,
                         float *d_b_fc1,
                         int T,
                         int aligned_in,
                         int aligned_out,
                         int num_threads)
{
    (void)num_threads; /* GEMM kernels manage their own threading */

    /* gemm_nn: C[M,N] = A[M,K] @ B[K,N] with
       A = d_output [T, out], B = W [out, in], C = d_input [T, in]. */
    gemm_nn_avx512(d_output, W_fc1, NULL, d_input,
                   T, aligned_in, aligned_out);

    /* gemm_tn: C[M,N] = A[K,M]^T @ B[K,N] with
       A = d_output (stored [K=T, M=out]), B = fc1_input [T, in],
       C = d_W [out, in]. */
    gemm_tn_avx512(d_output, fc1_input, NULL, d_W_fc1,
                   aligned_out, aligned_in, T);

    /* Bias gradient: accumulate each output column of d_output over time. */
#pragma omp parallel for schedule(static)
    for (int o = 0; o < aligned_out; ++o) {
        float acc = 0.0f;
        for (int t = 0; t < T; ++t) {
            acc += d_output[(size_t)t * aligned_out + o];
        }
        d_b_fc1[o] += acc;
    }
}
void gemm_nn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:339
void gemm_tn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:521

References gemm_nn_avx512(), and gemm_tn_avx512().

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ fc2_backward_kernel()

/**
 * Backward pass for FC2; mirrors fc1_backward_kernel.
 *
 *   d_input[T, in]  = d_output[T, out] @ W[out, in]
 *   d_W[out, in]    = d_output[T, out]^T @ fc2_input[T, in]
 *   d_b[out]       += column sums of d_output
 *
 * gemm_tn overwrites d_W, so gradient accumulation (if any) is the
 * caller's responsibility. Threading lives in the GEMM kernels.
 */
void fc2_backward_kernel(const float *d_output,
                         const float *fc2_input,
                         const float *W_fc2,
                         float *d_input,
                         float *d_W_fc2,
                         float *d_b_fc2,
                         int T,
                         int aligned_in,
                         int aligned_out,
                         int num_threads)
{
    (void)num_threads; /* GEMM kernels manage their own threading */

    /* gemm_nn: C[M,N] = A[M,K] @ B[K,N] with
       A = d_output [T, out], B = W [out, in], C = d_input [T, in]. */
    gemm_nn_avx512(d_output, W_fc2, NULL, d_input,
                   T, aligned_in, aligned_out);

    /* gemm_tn: C[M,N] = A[K,M]^T @ B[K,N] with
       A = d_output (stored [K=T, M=out]), B = fc2_input [T, in],
       C = d_W [out, in]. Overwrites d_W; assumed pre-zeroed by caller. */
    gemm_tn_avx512(d_output, fc2_input, NULL, d_W_fc2,
                   aligned_out, aligned_in, T);

    /* Bias gradient: accumulate each output column of d_output over time. */
#pragma omp parallel for schedule(static)
    for (int o = 0; o < aligned_out; ++o) {
        float acc = 0.0f;
        for (int t = 0; t < T; ++t) {
            acc += d_output[(size_t)t * aligned_out + o];
        }
        d_b_fc2[o] += acc;
    }
}

References gemm_nn_avx512(), and gemm_tn_avx512().

Referenced by ck_attention_project_head_major_backward(), ck_layer_backward_rmsnorm_swiglu(), and ck_qkv_project_head_major_backward().

◆ fused_mlp_swiglu_decode()

void fused_mlp_swiglu_decode ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
const float *  b_gate,
const float *  b_up,
const float *  b_down,
float *  output,
int  D,
int  Hff 
)

Definition at line 154 of file mlp_fused_decode.c.

165 {
166 #if defined(__AVX512F__)
167  // Initialize output with bias or zero
168  if (b_down) {
169  memcpy(output, b_down, D * sizeof(float));
170  } else {
171  memset(output, 0, D * sizeof(float));
172  }
173 
174  // Process intermediate dimension in tiles
175  // Each tile computes MLP_TILE_SIZE swiglu values and immediately
176  // accumulates them into the output
177 
178  /* Bounds check for stack allocation */
179  if (D > 4096) return;
180 
181  #pragma omp parallel
182  {
183  /* Thread-local accumulator on stack (no malloc!) */
184  float local_output[4096] __attribute__((aligned(64)));
185  memset(local_output, 0, D * sizeof(float));
186 
187  #pragma omp for schedule(static)
188  for (int t = 0; t < Hff; t += MLP_TILE_SIZE) {
189  int tile_end = (t + MLP_TILE_SIZE < Hff) ? t + MLP_TILE_SIZE : Hff;
190  int tile_size = tile_end - t;
191 
192  // Compute SwiGLU for this tile (stays in L1 cache)
193  float swiglu_tile[MLP_TILE_SIZE] __attribute__((aligned(64)));
194 
195  for (int j = t; j < tile_end; j++) {
196  const float *wg_row = &W_gate[j * D];
197  const float *wu_row = &W_up[j * D];
198 
199  // Compute gate = x @ W_gate[j] using AVX-512
200  __m512 gate_acc = _mm512_setzero_ps();
201  __m512 up_acc = _mm512_setzero_ps();
202 
203  int k = 0;
204  for (; k <= D - 16; k += 16) {
205  __m512 x_vec = _mm512_loadu_ps(&x[k]);
206  __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
207  __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
208 
209  gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
210  up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
211  }
212 
213  float gate = hsum512_ps(gate_acc);
214  float up = hsum512_ps(up_acc);
215 
216  // Scalar remainder
217  for (; k < D; k++) {
218  gate += x[k] * wg_row[k];
219  up += x[k] * wu_row[k];
220  }
221 
222  // Add biases
223  if (b_gate) gate += b_gate[j];
224  if (b_up) up += b_up[j];
225 
226  // SwiGLU: SiLU(gate) * up
227  swiglu_tile[j - t] = silu_scalar(gate) * up;
228  }
229 
230  // Accumulate into output via W_down
231  // output[i] += sum_j(swiglu_tile[j] * W_down[i, t+j])
232  for (int i = 0; i < D; i++) {
233  const float *wd_row = &W_down[i * Hff + t];
234 
235  __m512 acc = _mm512_setzero_ps();
236  int j = 0;
237  for (; j <= tile_size - 16; j += 16) {
238  __m512 sw_vec = _mm512_loadu_ps(&swiglu_tile[j]);
239  __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
240  acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
241  }
242 
243  float sum = hsum512_ps(acc);
244  for (; j < tile_size; j++) {
245  sum += swiglu_tile[j] * wd_row[j];
246  }
247 
248  local_output[i] += sum;
249  }
250  }
251 
252  // Reduce thread-local outputs
253  #pragma omp critical
254  {
255  for (int i = 0; i < D; i++) {
256  output[i] += local_output[i];
257  }
258  }
259  /* No free - stack buffer auto-deallocates */
260  }
261 
262 #else
263  // Scalar fallback (same algorithm, no SIMD)
264  if (b_down) {
265  memcpy(output, b_down, D * sizeof(float));
266  } else {
267  memset(output, 0, D * sizeof(float));
268  }
269 
270  for (int t = 0; t < Hff; t += MLP_TILE_SIZE) {
271  int tile_end = (t + MLP_TILE_SIZE < Hff) ? t + MLP_TILE_SIZE : Hff;
272  int tile_size = tile_end - t;
273 
274  float swiglu_tile[MLP_TILE_SIZE];
275 
276  for (int j = t; j < tile_end; j++) {
277  float gate = 0.0f;
278  float up = 0.0f;
279 
280  for (int k = 0; k < D; k++) {
281  gate += x[k] * W_gate[j * D + k];
282  up += x[k] * W_up[j * D + k];
283  }
284 
285  if (b_gate) gate += b_gate[j];
286  if (b_up) up += b_up[j];
287 
288  swiglu_tile[j - t] = silu_scalar(gate) * up;
289  }
290 
291  for (int i = 0; i < D; i++) {
292  for (int j = 0; j < tile_size; j++) {
293  output[i] += swiglu_tile[j] * W_down[i * Hff + t + j];
294  }
295  }
296  }
297 #endif
298 }
#define MLP_TILE_SIZE
static float silu_scalar(float x)
__attribute__((visibility("default"))) CKTokenizer *ck_tokenizer_create(CKTokenizerType type)

References __attribute__(), MLP_TILE_SIZE, and silu_scalar().

◆ fused_mlp_swiglu_decode_tiled()

void fused_mlp_swiglu_decode_tiled ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
const float *  b_gate,
const float *  b_up,
const float *  b_down,
float *  output,
int  D,
int  Hff 
)

Definition at line 429 of file mlp_fused_decode.c.

440 {
441  // Tile size chosen to fit in L2 with W_down tile
442  // Tile of swiglu: 256 floats = 1KB
443  // Tile of W_down: 256 * D floats = 256 * 896 * 4 = 896KB
444  // Fits in 2MB L2 with room for x and prefetch
445  const int TILE = 256;
446 
447 #if defined(__AVX512F__)
448  // Initialize output
449  #pragma omp parallel for schedule(static)
450  for (int i = 0; i < D; i++) {
451  output[i] = b_down ? b_down[i] : 0.0f;
452  }
453 
454  // Process tiles of intermediate dimension
455  for (int t = 0; t < Hff; t += TILE) {
456  int tile_end = (t + TILE < Hff) ? t + TILE : Hff;
457  int tile_size = tile_end - t;
458 
459  // Compute swiglu tile
460  float swiglu_tile[256] __attribute__((aligned(64)));
461 
462  #pragma omp parallel for schedule(static)
463  for (int jj = 0; jj < tile_size; jj++) {
464  int j = t + jj;
465  const float *wg_row = &W_gate[j * D];
466  const float *wu_row = &W_up[j * D];
467 
468  __m512 gate_acc = _mm512_setzero_ps();
469  __m512 up_acc = _mm512_setzero_ps();
470 
471  int k = 0;
472  for (; k <= D - 16; k += 16) {
473  __m512 x_vec = _mm512_loadu_ps(&x[k]);
474  __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
475  __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
476 
477  gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
478  up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
479  }
480 
481  float gate = hsum512_ps(gate_acc);
482  float up = hsum512_ps(up_acc);
483 
484  for (; k < D; k++) {
485  gate += x[k] * wg_row[k];
486  up += x[k] * wu_row[k];
487  }
488 
489  if (b_gate) gate += b_gate[j];
490  if (b_up) up += b_up[j];
491 
492  swiglu_tile[jj] = silu_scalar(gate) * up;
493  }
494 
495  // Accumulate into output (parallelize over D)
496  #pragma omp parallel for schedule(static)
497  for (int i = 0; i < D; i++) {
498  const float *wd_row = &W_down[i * Hff + t];
499 
500  __m512 acc = _mm512_setzero_ps();
501  int j = 0;
502  for (; j <= tile_size - 16; j += 16) {
503  __m512 sw_vec = _mm512_loadu_ps(&swiglu_tile[j]);
504  __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
505  acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
506  }
507 
508  float sum = hsum512_ps(acc);
509  for (; j < tile_size; j++) {
510  sum += swiglu_tile[j] * wd_row[j];
511  }
512 
513  // Atomic add (or use thread-local buffers for better perf)
514  #pragma omp atomic
515  output[i] += sum;
516  }
517  }
518 
519 #else
520  // Scalar fallback
521  for (int i = 0; i < D; i++) {
522  output[i] = b_down ? b_down[i] : 0.0f;
523  }
524 
525  for (int t = 0; t < Hff; t += TILE) {
526  int tile_end = (t + TILE < Hff) ? t + TILE : Hff;
527 
528  float swiglu_tile[256];
529 
530  for (int j = t; j < tile_end; j++) {
531  float gate = 0.0f, up = 0.0f;
532  for (int k = 0; k < D; k++) {
533  gate += x[k] * W_gate[j * D + k];
534  up += x[k] * W_up[j * D + k];
535  }
536  if (b_gate) gate += b_gate[j];
537  if (b_up) up += b_up[j];
538  swiglu_tile[j - t] = silu_scalar(gate) * up;
539  }
540 
541  for (int i = 0; i < D; i++) {
542  for (int j = t; j < tile_end; j++) {
543  output[i] += swiglu_tile[j - t] * W_down[i * Hff + j];
544  }
545  }
546  }
547 #endif
548 }

References __attribute__(), and silu_scalar().

Referenced by fused_mlp_swiglu_decode_v2().

◆ fused_mlp_swiglu_decode_v2()

void fused_mlp_swiglu_decode_v2 ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
const float *  b_gate,
const float *  b_up,
const float *  b_down,
float *  output,
int  D,
int  Hff 
)

Definition at line 318 of file mlp_fused_decode.c.

329 {
330  // For large Hff, use tiled version to avoid stack overflow
331  if (Hff > MAX_SWIGLU_STACK) {
332  fused_mlp_swiglu_decode_tiled(x, W_gate, W_up, W_down,
333  b_gate, b_up, b_down, output, D, Hff);
334  return;
335  }
336 
337 #if defined(__AVX512F__)
338  // Stack-allocated swiglu buffer (max 32KB)
339  float swiglu[MAX_SWIGLU_STACK] __attribute__((aligned(64)));
340 
341  // Phase 1: Compute all swiglu values (parallelize over Hff)
342  #pragma omp parallel for schedule(static)
343  for (int j = 0; j < Hff; j++) {
344  const float *wg_row = &W_gate[j * D];
345  const float *wu_row = &W_up[j * D];
346 
347  __m512 gate_acc = _mm512_setzero_ps();
348  __m512 up_acc = _mm512_setzero_ps();
349 
350  int k = 0;
351  for (; k <= D - 16; k += 16) {
352  __m512 x_vec = _mm512_loadu_ps(&x[k]);
353  __m512 wg_vec = _mm512_loadu_ps(&wg_row[k]);
354  __m512 wu_vec = _mm512_loadu_ps(&wu_row[k]);
355 
356  gate_acc = _mm512_fmadd_ps(x_vec, wg_vec, gate_acc);
357  up_acc = _mm512_fmadd_ps(x_vec, wu_vec, up_acc);
358  }
359 
360  float gate = hsum512_ps(gate_acc);
361  float up = hsum512_ps(up_acc);
362 
363  for (; k < D; k++) {
364  gate += x[k] * wg_row[k];
365  up += x[k] * wu_row[k];
366  }
367 
368  if (b_gate) gate += b_gate[j];
369  if (b_up) up += b_up[j];
370 
371  swiglu[j] = silu_scalar(gate) * up;
372  }
373 
374  // Phase 2: Down projection (parallelize over D)
375  #pragma omp parallel for schedule(static)
376  for (int i = 0; i < D; i++) {
377  const float *wd_row = &W_down[i * Hff];
378 
379  __m512 acc = _mm512_setzero_ps();
380  int j = 0;
381  for (; j <= Hff - 16; j += 16) {
382  __m512 sw_vec = _mm512_loadu_ps(&swiglu[j]);
383  __m512 wd_vec = _mm512_loadu_ps(&wd_row[j]);
384  acc = _mm512_fmadd_ps(sw_vec, wd_vec, acc);
385  }
386 
387  float sum = hsum512_ps(acc);
388  for (; j < Hff; j++) {
389  sum += swiglu[j] * wd_row[j];
390  }
391 
392  output[i] = sum + (b_down ? b_down[i] : 0.0f);
393  }
394 
395 #else
396  // Scalar fallback with stack buffer
397  float swiglu[MAX_SWIGLU_STACK];
398 
399  for (int j = 0; j < Hff; j++) {
400  float gate = 0.0f, up = 0.0f;
401  for (int k = 0; k < D; k++) {
402  gate += x[k] * W_gate[j * D + k];
403  up += x[k] * W_up[j * D + k];
404  }
405  if (b_gate) gate += b_gate[j];
406  if (b_up) up += b_up[j];
407  swiglu[j] = silu_scalar(gate) * up;
408  }
409 
410  for (int i = 0; i < D; i++) {
411  float sum = 0.0f;
412  for (int j = 0; j < Hff; j++) {
413  sum += swiglu[j] * W_down[i * Hff + j];
414  }
415  output[i] = sum + (b_down ? b_down[i] : 0.0f);
416  }
417 #endif
418 }
void fused_mlp_swiglu_decode_tiled(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *b_gate, const float *b_up, const float *b_down, float *output, int D, int Hff)
#define MAX_SWIGLU_STACK

References __attribute__(), fused_mlp_swiglu_decode_tiled(), MAX_SWIGLU_STACK, and silu_scalar().

Referenced by ck_mlp_swiglu_forward_fully_fused_token().

◆ fused_mlp_swiglu_prefill()

void fused_mlp_swiglu_prefill ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
float *  output,
int  seq_len,
int  hidden,
int  intermediate,
float *  scratch 
)

Fused MLP (Gate + Up + SwiGLU + Down) for prefill.

Tiles along token dimension to keep gate/up/hidden in L3 cache.

Parameters
scratch — Temporary buffer sized by fused_mlp_swiglu_scratch_size()

Definition at line 879 of file prefill_fused_gemm.c.

889 {
890  fused_mlp_swiglu_prefill_bias(x, W_gate, W_up, W_down,
891  NULL, NULL, NULL,
892  output, seq_len, hidden, intermediate,
893  scratch);
894 }
void fused_mlp_swiglu_prefill_bias(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *B_gate, const float *B_up, const float *B_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
Fused MLP for prefill with proper tiling.

References fused_mlp_swiglu_prefill_bias().

◆ fused_mlp_swiglu_prefill_bias()

void fused_mlp_swiglu_prefill_bias ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
const float *  B_gate,
const float *  B_up,
const float *  B_down,
float *  output,
int  seq_len,
int  hidden,
int  intermediate,
float *  scratch 
)

Fused MLP (Gate + Up + SwiGLU + Down) for prefill with biases.

Computes gate = x @ W_gate and up = x @ W_up, applies SwiGLU (hidden = silu(gate) * up), then the down projection — tiling along the intermediate dimension so the full seq_len × intermediate activation buffer is never materialized.

Definition at line 746 of file prefill_fused_gemm.c.

759 {
760  /* MLP is more complex because we have:
761  * gate = x @ W_gate
762  * up = x @ W_up
763  * hidden = silu(gate) * up
764  * out = hidden @ W_down
765  *
766  * The intermediate (gate, up, hidden) is large: seq_len × intermediate
767  * For Qwen2-0.5B: 1024 × 4864 × 4 = 19.4MB (way bigger than L3!)
768  *
769  * Strategy: Tile along intermediate dimension for gate/up,
770  * then fuse SwiGLU, then tile down projection.
771  */
772 
773  /* scratch layout:
774  * [gate_tile: TILE_M × TILE_N_INTER]
775  * [up_tile: TILE_M × TILE_N_INTER]
776  */
777  const int TILE_N_INTER = 512; /* Intermediate tile size */
778  float *gate_tile = scratch;
779  float *up_tile = scratch + (size_t)PREFILL_TILE_M * TILE_N_INTER;
780  float *hidden_tile = gate_tile; /* Reuse gate_tile for hidden after SwiGLU */
781 
782  /* For each chunk of intermediate dimension */
783  for (int inter_start = 0; inter_start < intermediate; inter_start += TILE_N_INTER) {
784  int tile_inter = (inter_start + TILE_N_INTER <= intermediate)
785  ? TILE_N_INTER : (intermediate - inter_start);
786 
787  const float *W_gate_tile = W_gate + (size_t)inter_start * hidden;
788  const float *W_up_tile = W_up + (size_t)inter_start * hidden;
789 
790  /* For each chunk of tokens */
791  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
792  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
793  ? PREFILL_TILE_M : (seq_len - m_start);
794 
795  const float *x_tile = x + (size_t)m_start * hidden;
796 
797  /* Compute gate and up projections for this tile */
798  gemm_tile_nt_strided(x_tile, W_gate_tile, gate_tile,
799  tile_m, tile_inter, hidden, tile_inter);
800  gemm_tile_nt_strided(x_tile, W_up_tile, up_tile,
801  tile_m, tile_inter, hidden, tile_inter);
802  if (B_gate) {
803  add_bias_tile(gate_tile, B_gate + inter_start, tile_m, tile_inter);
804  }
805  if (B_up) {
806  add_bias_tile(up_tile, B_up + inter_start, tile_m, tile_inter);
807  }
808 
809  /* Fused SwiGLU: hidden = silu(gate) * up */
810  for (int i = 0; i < tile_m; ++i) {
811  float *g = gate_tile + (size_t)i * tile_inter;
812  float *u = up_tile + (size_t)i * tile_inter;
813  for (int j = 0; j < tile_inter; ++j) {
814  float gv = g[j];
815  float silu = gv / (1.0f + expf(-gv));
816  g[j] = silu * u[j]; /* hidden_tile = gate_tile */
817  }
818  }
819 
820  /* Down projection: accumulate into output
821  * out[m_start:, :] += hidden_tile @ W_down[inter_start:, :]^T
822  */
823  const float *W_down_slice = W_down + (size_t)inter_start; /* Column slice */
824  float *out_tile = output + (size_t)m_start * hidden;
825 
826  /* This is trickier - W_down is [hidden × intermediate]
827  * We have hidden_tile[tile_m × tile_inter]
828  * We want out[tile_m × hidden] += hidden_tile × W_down[:, inter_start:inter_start+tile_inter]^T
829  *
830  * For proper accumulation, need to handle this carefully.
831  * For now, use a simpler approach: accumulate partial results.
832  */
833  for (int i = 0; i < tile_m; ++i) {
834  float *h = hidden_tile + (size_t)i * tile_inter;
835  float *o = out_tile + (size_t)i * hidden;
836 
837  for (int d = 0; d < hidden; ++d) {
838  const float *w_row = W_down + (size_t)d * intermediate + inter_start;
839  float sum = (inter_start == 0)
840  ? (B_down ? B_down[d] : 0.0f)
841  : o[d];
842 
843 #if defined(__AVX512F__)
844  __m512 acc = _mm512_setzero_ps();
845  int j = 0;
846  for (; j + 16 <= tile_inter; j += 16) {
847  __m512 hv = _mm512_loadu_ps(h + j);
848  __m512 wv = _mm512_loadu_ps(w_row + j);
849  acc = _mm512_fmadd_ps(hv, wv, acc);
850  }
851  sum += _mm512_reduce_add_ps(acc);
852  for (; j < tile_inter; ++j) {
853  sum += h[j] * w_row[j];
854  }
855 #elif defined(__AVX__)
856  __m256 acc = _mm256_setzero_ps();
857  int j = 0;
858  for (; j + 8 <= tile_inter; j += 8) {
859  __m256 hv = _mm256_loadu_ps(h + j);
860  __m256 wv = _mm256_loadu_ps(w_row + j);
861  acc = _mm256_add_ps(acc, _mm256_mul_ps(hv, wv));
862  }
863  sum += hsum256_prefill(acc);
864  for (; j < tile_inter; ++j) {
865  sum += h[j] * w_row[j];
866  }
867 #else
868  for (int j = 0; j < tile_inter; ++j) {
869  sum += h[j] * w_row[j];
870  }
871 #endif
872  o[d] = sum;
873  }
874  }
875  }
876  }
877 }
#define PREFILL_TILE_M
static void add_bias_tile(float *out, const float *bias, int tile_m, int out_dim)
static void gemm_tile_nt_strided(const float *A, const float *B_tile, float *C, int tile_m, int tile_n, int K, int C_stride)
GEMM tile with N-dimension tiling (weight reuse)
static void silu(float *x, int n)
Definition: v6_simple.c:159

References add_bias_tile(), gemm_tile_nt_strided(), PREFILL_TILE_M, and silu().

Referenced by fused_mlp_swiglu_prefill().

◆ fused_mlp_swiglu_prefill_w1w2_quant()

void fused_mlp_swiglu_prefill_w1w2_quant ( const float *  x,
const void *  W1,
const float *  B1,
CKDataType  w1_dt,
const void *  W2,
const float *  B2,
CKDataType  w2_dt,
float *  output,
int  seq_len,
int  embed_dim,
int  aligned_embed_dim,
int  intermediate_dim,
int  aligned_intermediate_dim,
void *  scratch 
)

Quantized fused MLP for prefill (W1=gate+up, W2=down)

W1 uses Q8_0 activations (Q5_0/Q8_0 weights), W2 uses Q8_K activations (Q4_K/Q6_K weights).

Uses Q8_0 activations for W1 (Q5_0/Q8_0 weights) and Q8_K activations for W2 (Q4_K/Q6_K weights).

Definition at line 965 of file prefill_fused_gemm.c.

980 {
981  if (!x || !W1 || !W2 || !output || !scratch) {
982  return;
983  }
984  if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
985  intermediate_dim <= 0 || aligned_intermediate_dim <= 0) {
986  return;
987  }
988  if (aligned_embed_dim < embed_dim || aligned_intermediate_dim < intermediate_dim) {
989  return;
990  }
991  if ((aligned_embed_dim % 32) != 0 || (aligned_intermediate_dim % 256) != 0) {
992  return;
993  }
994  if (!mlp_q8_0_dtype_supported(w1_dt) || !mlp_q8_k_dtype_supported(w2_dt)) {
995  return;
996  }
997 
998  const int tile_m_max = PREFILL_TILE_M;
999  const int inter = aligned_intermediate_dim;
1000  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
1001  const size_t q8k_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_intermediate_dim);
1002  const size_t w1_row_bytes = ck_dtype_row_bytes(w1_dt, (size_t)aligned_embed_dim);
1003 
1004  uint8_t *scratch_bytes = (uint8_t *)scratch;
1005  size_t q8_bytes = (size_t)tile_m_max * q8_row_bytes;
1006  size_t gate_bytes = (size_t)tile_m_max * (size_t)inter * sizeof(float);
1007  size_t up_bytes = gate_bytes;
1008  size_t gate_offset = align_up_size(q8_bytes, 64);
1009  size_t up_offset = gate_offset + align_up_size(gate_bytes, 64);
1010  size_t q8k_offset = up_offset + align_up_size(up_bytes, 64);
1011 
1012  uint8_t *q8_tile = scratch_bytes;
1013  float *gate_tile = (float *)(scratch_bytes + gate_offset);
1014  float *up_tile = (float *)(scratch_bytes + up_offset);
1015  uint8_t *q8k_tile = scratch_bytes + q8k_offset;
1016 
1017  const uint8_t *w1_base = (const uint8_t *)W1;
1018  const uint8_t *w_gate = w1_base;
1019  const uint8_t *w_up = w1_base + (size_t)inter * w1_row_bytes;
1020 
1021  const float *b_gate = B1;
1022  const float *b_up = B1 ? (B1 + (size_t)inter) : NULL;
1023 
1024  for (int m_start = 0; m_start < seq_len; m_start += tile_m_max) {
1025  int tile_m = (m_start + tile_m_max <= seq_len)
1026  ? tile_m_max : (seq_len - m_start);
1027 
1028  const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
1029  float *out_tile = output + (size_t)m_start * (size_t)aligned_embed_dim;
1030 
1031  for (int t = 0; t < tile_m; ++t) {
1032  const float *row = x_tile + (size_t)t * (size_t)aligned_embed_dim;
1033  quantize_row_q8_0(row,
1034  q8_tile + (size_t)t * q8_row_bytes,
1035  aligned_embed_dim);
1036  }
1037 
1038  gemm_nt_q8_0_mlp_dispatch(q8_tile, w_gate, b_gate, gate_tile,
1039  tile_m, inter, aligned_embed_dim, w1_dt);
1040  gemm_nt_q8_0_mlp_dispatch(q8_tile, w_up, b_up, up_tile,
1041  tile_m, inter, aligned_embed_dim, w1_dt);
1042 
1043  for (int i = 0; i < tile_m; ++i) {
1044  float *g = gate_tile + (size_t)i * (size_t)inter;
1045  float *u = up_tile + (size_t)i * (size_t)inter;
1046  for (int j = 0; j < inter; ++j) {
1047  g[j] = silu_prefill(g[j]) * u[j];
1048  }
1049  }
1050 
1051  for (int i = 0; i < tile_m; ++i) {
1052  const float *row = gate_tile + (size_t)i * (size_t)inter;
1053  quantize_row_q8_k(row,
1054  q8k_tile + (size_t)i * q8k_row_bytes,
1055  aligned_intermediate_dim);
1056  }
1057 
1058  gemm_nt_q8_k_mlp_dispatch(q8k_tile, W2, B2, out_tile,
1059  tile_m, aligned_embed_dim, aligned_intermediate_dim, w2_dt);
1060  }
1061 }
@ CK_DT_Q8_K
Definition: ckernel_dtype.h:43
void quantize_row_q8_k(const float *x, void *y, int k)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
static void gemm_nt_q8_0_mlp_dispatch(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
static size_t align_up_size(size_t value, size_t align)
static int mlp_q8_k_dtype_supported(CKDataType dt)
static float silu_prefill(float x)
static int mlp_q8_0_dtype_supported(CKDataType dt)
static void gemm_nt_q8_k_mlp_dispatch(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)

References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), gemm_nt_q8_0_mlp_dispatch(), gemm_nt_q8_k_mlp_dispatch(), mlp_q8_0_dtype_supported(), mlp_q8_k_dtype_supported(), PREFILL_TILE_M, quantize_row_q8_0(), quantize_row_q8_k(), and silu_prefill().

Referenced by mega_fused_outproj_mlp_prefill().

◆ fused_mlp_swiglu_prefill_w1w2_quant_scratch_size()

size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size ( int  aligned_embed_dim,
int  aligned_intermediate_dim 
)

Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.

Definition at line 1063 of file prefill_fused_gemm.c.

1065 {
1066  if (aligned_embed_dim <= 0 || aligned_intermediate_dim <= 0) {
1067  return 0;
1068  }
1069  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
1070  const size_t q8k_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_intermediate_dim);
1071  const size_t q8_bytes = (size_t)PREFILL_TILE_M * q8_row_bytes;
1072  const size_t gate_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_intermediate_dim * sizeof(float);
1073  const size_t up_bytes = gate_bytes;
1074  const size_t q8k_bytes = (size_t)PREFILL_TILE_M * q8k_row_bytes;
1075 
1076  return align_up_size(q8_bytes, 64) +
1077  align_up_size(gate_bytes, 64) +
1078  align_up_size(up_bytes, 64) +
1079  align_up_size(q8k_bytes, 64);
1080 }

References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), and PREFILL_TILE_M.

Referenced by mega_fused_outproj_mlp_prefill_scratch_size().

◆ fused_mlp_swiglu_scratch_size()

size_t fused_mlp_swiglu_scratch_size ( int  intermediate)

Get scratch buffer size for fused_mlp_swiglu_prefill.

Returns the required scratch size in bytes: two PREFILL_TILE_M × TILE_N_INTER float tiles (gate_tile and up_tile).

Definition at line 899 of file prefill_fused_gemm.c.

899  {
900  const int TILE_N_INTER = 512;
901  /* gate_tile + up_tile */
902  return 2 * (size_t)PREFILL_TILE_M * TILE_N_INTER * sizeof(float);
903 }

References PREFILL_TILE_M.

◆ fused_rmsnorm_qkv_prefill()

void fused_rmsnorm_qkv_prefill ( const float *  x,
const float *  gamma,
const float *  Wq,
const float *  Wk,
const float *  Wv,
float *  Q,
float *  K,
float *  V,
int  seq_len,
int  hidden,
int  q_dim,
int  kv_dim,
float  eps,
float *  scratch 
)

Fused RMSNorm + QKV projection for prefill.

Tiles along token dimension to keep intermediate x_norm in L2 cache. Avoids ~7MB DRAM traffic per layer for seq_len=1024, hidden=896.

Parameters
scratch — Temporary buffer sized by fused_rmsnorm_qkv_scratch_size()

Fused RMSNorm + QKV projection for prefill.

KEY INSIGHT: For Qwen2-0.5B, all QKV weights fit in L3: Wq (896×896) + Wk (128×896) + Wv (128×896) = 4.1MB < 6MB L3

So we use M-tiling (tokens) only:

  1. For each token tile: a. Compute RMSNorm ONCE into scratch (x_norm stays in L2) b. Do all three GEMMs (Q, K, V) against cached x_norm c. Weights stay hot in L3 across all token tiles

This avoids both:

  • Large x_norm intermediate buffer (only TILE_M × hidden in L2)
  • RMSNorm recomputation (done once per token tile, used 3×)

Definition at line 393 of file prefill_fused_gemm.c.

408 {
409  /* scratch is x_norm tile: [TILE_M × hidden] fits in L2 */
410 
411  /* Process token tiles - weights stay in L3 across all tiles */
412  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
413  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
414  ? PREFILL_TILE_M : (seq_len - m_start);
415 
416  const float *x_tile = x + (size_t)m_start * hidden;
417 
418  /* Step 1: RMSNorm for this token tile (computed ONCE, used 3×) */
419  rmsnorm_tile(x_tile, gamma, scratch, tile_m, hidden, hidden, eps);
420 
421  /* Step 2: Q projection - x_norm is hot in L2, Wq hot in L3 */
422  float *Q_tile = Q + (size_t)m_start * q_dim;
423  gemm_tile_nt_strided(scratch, Wq, Q_tile, tile_m, q_dim, hidden, q_dim);
424 
425  /* Step 3: K projection - x_norm still hot, Wk displaces some Wq */
426  float *K_tile = K + (size_t)m_start * kv_dim;
427  gemm_tile_nt_strided(scratch, Wk, K_tile, tile_m, kv_dim, hidden, kv_dim);
428 
429  /* Step 4: V projection - x_norm still hot, Wv displaces Wk */
430  float *V_tile = V + (size_t)m_start * kv_dim;
431  gemm_tile_nt_strided(scratch, Wv, V_tile, tile_m, kv_dim, hidden, kv_dim);
432  }
433 }
static void rmsnorm_tile(const float *input, const float *gamma, float *output, int tile_m, int embed_dim, int aligned_embed_dim, float eps)
Compute RMSNorm for a tile of tokens.

References gemm_tile_nt_strided(), PREFILL_TILE_M, and rmsnorm_tile().

◆ fused_rmsnorm_qkv_prefill_head_major()

void fused_rmsnorm_qkv_prefill_head_major ( const float *  x,
const float *  gamma,
const float *  Wq,
const float *  Bq,
const float *  Wk,
const float *  Bk,
const float *  Wv,
const float *  Bv,
float *  Q,
float *  K,
float *  V,
int  seq_len,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
int  kv_stride_tokens,
float  eps,
float *  scratch 
)

Fused RMSNorm + QKV projection for prefill (head-major outputs)

Writes Q as [num_heads, seq_len, aligned_head_dim] and K/V with stride kv_stride_tokens for KV-cache compatibility.

Q is written as [num_heads, seq_len, aligned_head_dim]. K/V are written with kv_stride_tokens for KV-cache compatibility.

Definition at line 441 of file prefill_fused_gemm.c.

460 {
461  if (!x || !gamma || !Wq || !Wk || !Wv || !Q || !K || !V || !scratch) {
462  return;
463  }
464  if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
465  head_dim <= 0 || aligned_head_dim <= 0 ||
466  num_heads <= 0 || num_kv_heads <= 0) {
467  return;
468  }
469  if (kv_stride_tokens < seq_len) {
470  return;
471  }
472 
473  const size_t q_head_stride = (size_t)seq_len * (size_t)aligned_head_dim;
474  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
475  const size_t head_w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
476 
477  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
478  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
479  ? PREFILL_TILE_M : (seq_len - m_start);
480 
481  const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
482  rmsnorm_tile(x_tile, gamma, scratch, tile_m, embed_dim, aligned_embed_dim, eps);
483 
484  for (int h = 0; h < num_heads; ++h) {
485  const float *wq_h = Wq + (size_t)h * head_w_stride;
486  const float *bq_h = Bq ? (Bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
487  float *q_h = Q + (size_t)h * q_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
488 
489  gemm_tile_nt_strided(scratch, wq_h, q_h,
490  tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
491  add_bias_tile(q_h, bq_h, tile_m, aligned_head_dim);
492  }
493 
494  for (int h = 0; h < num_kv_heads; ++h) {
495  const float *wk_h = Wk + (size_t)h * head_w_stride;
496  const float *wv_h = Wv + (size_t)h * head_w_stride;
497  const float *bk_h = Bk ? (Bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
498  const float *bv_h = Bv ? (Bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
499  float *k_h = K + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
500  float *v_h = V + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
501 
502  gemm_tile_nt_strided(scratch, wk_h, k_h,
503  tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
504  add_bias_tile(k_h, bk_h, tile_m, aligned_head_dim);
505 
506  gemm_tile_nt_strided(scratch, wv_h, v_h,
507  tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
508  add_bias_tile(v_h, bv_h, tile_m, aligned_head_dim);
509  }
510  }
511 }

References add_bias_tile(), gemm_tile_nt_strided(), PREFILL_TILE_M, and rmsnorm_tile().

Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().

◆ fused_rmsnorm_qkv_prefill_head_major_quant()

/**
 * Fused RMSNorm + QKV projection for prefill (head-major, quantized weights,
 * Q8 activations).
 *
 * Per tile of up to PREFILL_TILE_M tokens: RMSNorm the input, quantize the
 * normalized rows (Q8_0 activations for Q5_0/Q8_0 weights, Q8_K activations
 * for Q4_K/Q6_K weights), then run one quantized GEMM per head for Q, K and V.
 * K and V rows are placed at stride kv_stride_tokens so they land directly in
 * the KV-cache layout.
 *
 * All three weight tensors must belong to the same quantization family
 * (all Q8_0-family or all Q8_K-family); mixed families are rejected.
 * On any invalid argument the function silently returns without writing.
 *
 * @param x        fp32 input, [seq_len, aligned_embed_dim]
 * @param gamma    RMSNorm scale, embed_dim entries
 * @param Wq,Bq,wq_dt  per-head query weight slabs / optional bias / dtype
 * @param Wk,Bk,wk_dt  per-head key weight slabs / optional bias / dtype
 * @param Wv,Bv,wv_dt  per-head value weight slabs / optional bias / dtype
 * @param Q        output, [num_heads, seq_len, aligned_head_dim]
 * @param K,V      output, [num_kv_heads, kv_stride_tokens, aligned_head_dim]
 * @param scratch  caller buffer sized by
 *                 fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size()
 */
void fused_rmsnorm_qkv_prefill_head_major_quant(
    const float *x, const float *gamma,
    const void *Wq, const float *Bq, CKDataType wq_dt,
    const void *Wk, const float *Bk, CKDataType wk_dt,
    const void *Wv, const float *Bv, CKDataType wv_dt,
    float *Q, float *K, float *V,
    int seq_len, int embed_dim, int aligned_embed_dim,
    int num_heads, int num_kv_heads,
    int head_dim, int aligned_head_dim,
    int kv_stride_tokens, float eps, void *scratch)
{
    if (!x || !gamma || !Wq || !Wk || !Wv || !Q || !K || !V || !scratch) {
        return;
    }
    if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
        head_dim <= 0 || aligned_head_dim <= 0 ||
        num_heads <= 0 || num_kv_heads <= 0) {
        return;
    }
    /* Q8 row quantization works on 32-element blocks. */
    if (aligned_embed_dim % 32 != 0) {
        return;
    }
    /* Every token's K/V row must fit inside the cache stride. */
    if (kv_stride_tokens < seq_len) {
        return;
    }
    /* Determine quantization path: Q8_0 activations for Q5_0/Q8_0 weights,
     * Q8_K activations for Q4_K/Q6_K weights. All QKV weights must use
     * the same quantization family. */
    int use_q8_k_path = qkv_q8_k_dtype_supported(wq_dt);
    int use_q8_0_path = qkv_q8_0_dtype_supported(wq_dt);

    if (!use_q8_k_path && !use_q8_0_path) {
        /* Unsupported dtype for wq */
        return;
    }

    /* Verify all dtypes are from the same family */
    if (use_q8_k_path) {
        if (!qkv_q8_k_dtype_supported(wk_dt) || !qkv_q8_k_dtype_supported(wv_dt)) {
            return; /* Mixed Q8_K and Q8_0 paths not supported */
        }
    } else {
        if (!qkv_q8_0_dtype_supported(wk_dt) || !qkv_q8_0_dtype_supported(wv_dt)) {
            return;
        }
    }

    /* Scratch layout: fp32 normalized tile first, then (64-byte aligned)
     * the quantized activation tile. */
    const size_t float_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_embed_dim * sizeof(float);
    /* Q8_K has larger blocks (256) than Q8_0 (32), so use appropriate size */
    const CKDataType act_quant_type = use_q8_k_path ? CK_DT_Q8_K : CK_DT_Q8_0;
    const size_t q8_row_bytes = ck_dtype_row_bytes(act_quant_type, (size_t)aligned_embed_dim);
    const size_t q8_bytes = (size_t)PREFILL_TILE_M * q8_row_bytes;
    const size_t q8_offset = align_up_size(float_bytes, 64);

    float *normed = (float *)scratch;
    uint8_t *q8_tile = (uint8_t *)scratch + q8_offset;
    /* Size computed only for documentation; the scratch_size helper
     * guarantees the buffer is large enough. */
    (void)q8_bytes;

    /* Per-head strides: Q is packed by seq_len, K/V by the cache stride;
     * weight slabs are byte-strided because rows are quantized. */
    const size_t q_head_stride = (size_t)seq_len * (size_t)aligned_head_dim;
    const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
    const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
    const size_t wq_head_bytes = ck_dtype_row_bytes(wq_dt, head_w_elems);
    const size_t wk_head_bytes = ck_dtype_row_bytes(wk_dt, head_w_elems);
    const size_t wv_head_bytes = ck_dtype_row_bytes(wv_dt, head_w_elems);

    for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
        int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
                         ? PREFILL_TILE_M : (seq_len - m_start);

        const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
        rmsnorm_tile(x_tile, gamma, normed, tile_m, embed_dim, aligned_embed_dim, eps);

        /* Quantize activations to appropriate format */
        for (int t = 0; t < tile_m; ++t) {
            const float *row = normed + (size_t)t * (size_t)aligned_embed_dim;
            if (use_q8_k_path) {
                quantize_row_q8_k(row,
                                  q8_tile + (size_t)t * q8_row_bytes,
                                  aligned_embed_dim);
            } else {
                quantize_row_q8_0(row,
                                  q8_tile + (size_t)t * q8_row_bytes,
                                  aligned_embed_dim);
            }
        }

        /* Query projection: one GEMM per attention head. */
        for (int h = 0; h < num_heads; ++h) {
            const uint8_t *wq_h = (const uint8_t *)Wq + (size_t)h * wq_head_bytes;
            const float *bq_h = Bq ? (Bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
            float *q_h = Q + (size_t)h * q_head_stride + (size_t)m_start * (size_t)aligned_head_dim;

            if (use_q8_k_path) {
                gemm_nt_q8_k_qkv_dispatch(q8_tile, wq_h, bq_h, q_h,
                                          tile_m, aligned_head_dim, aligned_embed_dim, wq_dt);
            } else {
                gemm_nt_q8_0_dispatch(q8_tile, wq_h, bq_h, q_h,
                                      tile_m, aligned_head_dim, aligned_embed_dim, wq_dt);
            }
        }

        /* Key/value projections: one GEMM pair per KV head (GQA aware). */
        for (int h = 0; h < num_kv_heads; ++h) {
            const uint8_t *wk_h = (const uint8_t *)Wk + (size_t)h * wk_head_bytes;
            const uint8_t *wv_h = (const uint8_t *)Wv + (size_t)h * wv_head_bytes;
            const float *bk_h = Bk ? (Bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
            const float *bv_h = Bv ? (Bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
            float *k_h = K + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
            float *v_h = V + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;

            if (use_q8_k_path) {
                gemm_nt_q8_k_qkv_dispatch(q8_tile, wk_h, bk_h, k_h,
                                          tile_m, aligned_head_dim, aligned_embed_dim, wk_dt);
                gemm_nt_q8_k_qkv_dispatch(q8_tile, wv_h, bv_h, v_h,
                                          tile_m, aligned_head_dim, aligned_embed_dim, wv_dt);
            } else {
                gemm_nt_q8_0_dispatch(q8_tile, wk_h, bk_h, k_h,
                                      tile_m, aligned_head_dim, aligned_embed_dim, wk_dt);
                gemm_nt_q8_0_dispatch(q8_tile, wv_h, bv_h, v_h,
                                      tile_m, aligned_head_dim, aligned_embed_dim, wv_dt);
            }
        }
    }
}
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
static void gemm_nt_q8_0_dispatch(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
static int qkv_q8_k_dtype_supported(CKDataType dt)
static void gemm_nt_q8_k_qkv_dispatch(const void *A_q8k, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
static int qkv_q8_0_dtype_supported(CKDataType dt)

References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), gemm_nt_q8_0_dispatch(), gemm_nt_q8_k_qkv_dispatch(), PREFILL_TILE_M, qkv_q8_0_dtype_supported(), qkv_q8_k_dtype_supported(), quantize_row_q8_0(), quantize_row_q8_k(), and rmsnorm_tile().

Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().

◆ fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size()

size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size ( int  aligned_embed_dim)

Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.

Definition at line 651 of file prefill_fused_gemm.c.

651  {
652  if (aligned_embed_dim <= 0) {
653  return 0;
654  }
655  const size_t float_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_embed_dim * sizeof(float);
656  /* Use max of Q8_0 and Q8_K sizes to support both paths */
657  const size_t q8_0_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
658  const size_t q8_k_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_embed_dim);
659  const size_t q8_row_bytes = (q8_k_row_bytes > q8_0_row_bytes) ? q8_k_row_bytes : q8_0_row_bytes;
660  const size_t q8_bytes = (size_t)PREFILL_TILE_M * q8_row_bytes;
661  return align_up_size(float_bytes, 64) + q8_bytes;
662 }

References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), and PREFILL_TILE_M.

Referenced by mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), mega_fused_attention_prefill_q8_0_scratch_size(), and mega_fused_attention_prefill_scratch_size().

◆ fused_rmsnorm_qkv_scratch_size()

size_t fused_rmsnorm_qkv_scratch_size ( int  hidden)

Get scratch buffer size for fused_rmsnorm_qkv_prefill.

Get scratch buffer size for fused_rmsnorm_qkv_prefill.

Definition at line 739 of file prefill_fused_gemm.c.

739  {
740  return (size_t)PREFILL_TILE_M * hidden * sizeof(float);
741 }

References PREFILL_TILE_M.

◆ geglu_backward_fp32()

void geglu_backward_fp32 ( const float *  x,
const float *  d_out,
float *  d_x,
int  tokens,
int  dim 
)

GeGLU backward pass (fp32)

Test:
test_geglu.py::TestGeGLU::test_geglu_backward_fp32

dL/dx given dL/d(out) where out = GELU(a) * b Chain rule: dL/da = dL/dout * d(GELU)/da * b dL/db = dL/dout * GELU(a)

After changes: make test

Definition at line 843 of file gelu_kernels.c.

848 {
849  const float sqrt_2_over_pi = 0.7978845608f;
850  const float coeff = 0.044715f;
851 
852  const int inner_dim = dim * 2;
853 
854  for (int t = 0; t < tokens; ++t) {
855  const float *x_ptr = x + (size_t)t * inner_dim;
856  const float *d_out_ptr = d_out + (size_t)t * dim;
857  float *d_x_ptr = d_x + (size_t)t * inner_dim;
858 
859  for (int d = 0; d < dim; ++d) {
860  float a = x_ptr[d];
861  float b = x_ptr[dim + d];
862  float dout = d_out_ptr[d];
863 
864  // GELU(a) derivative components
865  float a2 = a * a;
866  float a3 = a2 * a;
867  float g = sqrt_2_over_pi * (a + coeff * a3);
868  float tanh_g = tanhf(g);
869  float sech2_g = 1.0f - tanh_g * tanh_g;
870  float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * a2);
871 
872  // d(GELU)/da = 0.5 * (1 + tanh(g)) + 0.5 * a * sech^2(g) * g'
873  float d_gelu = 0.5f * (1.0f + tanh_g) + 0.5f * a * sech2_g * g_prime;
874 
875  // dL/da = dL/dout * d(GELU)/da * b
876  d_x_ptr[d] = dout * d_gelu * b;
877 
878  // dL/db = dL/dout * GELU(a)
879  float gelu_a = 0.5f * a * (1.0f + tanh_g);
880  d_x_ptr[dim + d] = dout * gelu_a;
881  }
882  }
883 }

◆ geglu_forward_bf16()

/**
 * GeGLU forward pass (bf16 wrapper).
 *
 * Widens the bf16 input to fp32, runs geglu_forward_fp32, and narrows the
 * result back to bf16.  `scratch` must hold 3 * tokens * dim floats:
 * the first 2*tokens*dim are the widened [a, b] input, the remaining
 * tokens*dim hold the fp32 output.  Input and output live in separate
 * regions because the input row (2*dim) is wider than the output row (dim),
 * so in-place computation would overlap when tokens > 1.
 */
void geglu_forward_bf16(const uint16_t *x, uint16_t *out,
                        int tokens, int dim, float *scratch)
{
    if (!x || !out || !scratch) {
        return;
    }

    const size_t out_elems = (size_t)tokens * (size_t)dim;
    const size_t in_elems = out_elems * 2; /* [a, b] = 2*dim per token */
    float *wide_in = scratch;
    float *wide_out = scratch + in_elems;

    bf16_tensor_to_float(x, wide_in, in_elems);
    geglu_forward_fp32(wide_in, wide_out, tokens, dim);
    float_tensor_to_bf16(wide_out, out, out_elems);
}
void geglu_forward_fp32(const float *x, float *out, int tokens, int dim)
Definition: gelu_kernels.c:623

References bf16_tensor_to_float(), float_tensor_to_bf16(), and geglu_forward_fp32().

◆ geglu_forward_fp32()

void geglu_forward_fp32 ( const float *  x,
float *  out,
int  tokens,
int  dim 
)

GeGLU forward pass (fp32)

Test:
test_geglu.py::TestGeGLU::test_geglu_forward_fp32

Computes out = GELU(a) * b where x = [a, b] along last dimension. Input shape: [tokens, 2 * dim], Output shape: [tokens, dim]

After changes: make test

Definition at line 623 of file gelu_kernels.c.

624 {
625  const float sqrt_2_over_pi = 0.7978845608f;
626  const float coeff = 0.044715f;
627 
628  const int inner_dim = dim * 2;
629 
630 #if defined(__AVX512F__)
631  const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
632  const __m512 coeff_vec = _mm512_set1_ps(coeff);
633  const __m512 half_vec = _mm512_set1_ps(0.5f);
634  const __m512 one_vec = _mm512_set1_ps(1.0f);
635 
636  for (int t = 0; t < tokens; ++t) {
637  const float *x_ptr = x + (size_t)t * inner_dim;
638  float *out_ptr = out + (size_t)t * dim;
639 
640  int d = 0;
641  // Process first half (a) with GELU, second half (b) directly
642  for (; d + 32 <= dim; d += 32) {
643  // Load a (first half of inner_dim)
644  __m512 a0 = _mm512_loadu_ps(&x_ptr[d]);
645  __m512 a1 = _mm512_loadu_ps(&x_ptr[d + 16]);
646 
647  // Compute GELU(a)
648  __m512 a0_sq = _mm512_mul_ps(a0, a0);
649  __m512 a0_cu = _mm512_mul_ps(a0_sq, a0);
650  __m512 a1_sq = _mm512_mul_ps(a1, a1);
651  __m512 a1_cu = _mm512_mul_ps(a1_sq, a1);
652 
653  // inner = sqrt(2/pi) * (a + 0.044715 * a^3)
654  __m512 inner0 = _mm512_fmadd_ps(coeff_vec, a0_cu, a0);
655  __m512 inner1 = _mm512_fmadd_ps(coeff_vec, a1_cu, a1);
656  inner0 = _mm512_mul_ps(sqrt_2_pi_vec, inner0);
657  inner1 = _mm512_mul_ps(sqrt_2_pi_vec, inner1);
658 
659  // tanh(inner)
660  __m512 tanh0 = tanh512_fast(inner0);
661  __m512 tanh1 = tanh512_fast(inner1);
662 
663  // GELU = 0.5 * a * (1 + tanh)
664  __m512 gelu0 = _mm512_mul_ps(half_vec, _mm512_mul_ps(a0, _mm512_add_ps(one_vec, tanh0)));
665  __m512 gelu1 = _mm512_mul_ps(half_vec, _mm512_mul_ps(a1, _mm512_add_ps(one_vec, tanh1)));
666 
667  // Load b (second half of inner_dim)
668  __m512 b0 = _mm512_loadu_ps(&x_ptr[dim + d]);
669  __m512 b1 = _mm512_loadu_ps(&x_ptr[dim + d + 16]);
670 
671  // out = GELU(a) * b
672  _mm512_storeu_ps(&out_ptr[d], _mm512_mul_ps(gelu0, b0));
673  _mm512_storeu_ps(&out_ptr[d + 16], _mm512_mul_ps(gelu1, b1));
674  }
675  // Handle remaining
676  for (; d < dim; ++d) {
677  float a = x_ptr[d];
678  float b = x_ptr[dim + d];
679  float a3 = a * a * a;
680  float inner = sqrt_2_over_pi * (a + coeff * a3);
681  float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
682  out_ptr[d] = gelu_a * b;
683  }
684  }
685 
686 #elif defined(__AVX2__)
687  const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
688  const __m256 coeff_vec = _mm256_set1_ps(coeff);
689  const __m256 half_vec = _mm256_set1_ps(0.5f);
690  const __m256 one_vec = _mm256_set1_ps(1.0f);
691 
692  for (int t = 0; t < tokens; ++t) {
693  const float *x_ptr = x + (size_t)t * inner_dim;
694  float *out_ptr = out + (size_t)t * dim;
695 
696  int d = 0;
697  for (; d + 16 <= dim; d += 16) {
698  // Load a
699  __m256 a0 = _mm256_loadu_ps(&x_ptr[d]);
700  __m256 a1 = _mm256_loadu_ps(&x_ptr[d + 8]);
701 
702  // GELU(a)
703  __m256 a0_sq = _mm256_mul_ps(a0, a0);
704  __m256 a0_cu = _mm256_mul_ps(a0_sq, a0);
705  __m256 a1_sq = _mm256_mul_ps(a1, a1);
706  __m256 a1_cu = _mm256_mul_ps(a1_sq, a1);
707 
708  __m256 inner0 = _mm256_fmadd_ps(coeff_vec, a0_cu, a0);
709  __m256 inner1 = _mm256_fmadd_ps(coeff_vec, a1_cu, a1);
710  inner0 = _mm256_mul_ps(sqrt_2_pi_vec, inner0);
711  inner1 = _mm256_mul_ps(sqrt_2_pi_vec, inner1);
712 
713  __m256 tanh0 = tanh256_fast(inner0);
714  __m256 tanh1 = tanh256_fast(inner1);
715 
716  __m256 gelu0 = _mm256_mul_ps(half_vec, _mm256_mul_ps(a0, _mm256_add_ps(one_vec, tanh0)));
717  __m256 gelu1 = _mm256_mul_ps(half_vec, _mm256_mul_ps(a1, _mm256_add_ps(one_vec, tanh1)));
718 
719  // b
720  __m256 b0 = _mm256_loadu_ps(&x_ptr[dim + d]);
721  __m256 b1 = _mm256_loadu_ps(&x_ptr[dim + d + 8]);
722 
723  _mm256_storeu_ps(&out_ptr[d], _mm256_mul_ps(gelu0, b0));
724  _mm256_storeu_ps(&out_ptr[d + 8], _mm256_mul_ps(gelu1, b1));
725  }
726  for (; d < dim; ++d) {
727  float a = x_ptr[d];
728  float b = x_ptr[dim + d];
729  float a3 = a * a * a;
730  float inner = sqrt_2_over_pi * (a + coeff * a3);
731  float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
732  out_ptr[d] = gelu_a * b;
733  }
734  }
735 
736 #elif defined(__AVX__)
737  const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
738  const __m256 coeff_vec = _mm256_set1_ps(coeff);
739  const __m256 half_vec = _mm256_set1_ps(0.5f);
740  const __m256 one_vec = _mm256_set1_ps(1.0f);
741 
742  float inner_arr[8] __attribute__((aligned(32)));
743  float tanh_arr[8] __attribute__((aligned(32)));
744 
745  for (int t = 0; t < tokens; ++t) {
746  const float *x_ptr = x + (size_t)t * inner_dim;
747  float *out_ptr = out + (size_t)t * dim;
748 
749  int d = 0;
750  for (; d + 8 <= dim; d += 8) {
751  __m256 a = _mm256_loadu_ps(&x_ptr[d]);
752  __m256 a_sq = _mm256_mul_ps(a, a);
753  __m256 a_cu = _mm256_mul_ps(a_sq, a);
754 
755  __m256 coeff_a_cu = _mm256_mul_ps(coeff_vec, a_cu);
756  __m256 inner = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(a, coeff_a_cu));
757 
758  _mm256_store_ps(inner_arr, inner);
759  for (int j = 0; j < 8; ++j) {
760  tanh_arr[j] = tanhf(inner_arr[j]);
761  }
762  __m256 tanh_val = _mm256_load_ps(tanh_arr);
763 
764  __m256 gelu = _mm256_mul_ps(half_vec, _mm256_mul_ps(a, _mm256_add_ps(one_vec, tanh_val)));
765  __m256 b = _mm256_loadu_ps(&x_ptr[dim + d]);
766 
767  _mm256_storeu_ps(&out_ptr[d], _mm256_mul_ps(gelu, b));
768  }
769  for (; d < dim; ++d) {
770  float a = x_ptr[d];
771  float b = x_ptr[dim + d];
772  float a3 = a * a * a;
773  float inner = sqrt_2_over_pi * (a + coeff * a3);
774  float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
775  out_ptr[d] = gelu_a * b;
776  }
777  }
778 
779 #else
780  // Scalar fallback
781  for (int t = 0; t < tokens; ++t) {
782  const float *x_ptr = x + (size_t)t * inner_dim;
783  float *out_ptr = out + (size_t)t * dim;
784 
785  for (int d = 0; d < dim; ++d) {
786  float a = x_ptr[d];
787  float b = x_ptr[dim + d];
788  float a3 = a * a * a;
789  float inner = sqrt_2_over_pi * (a + coeff * a3);
790  float gelu_a = 0.5f * a * (1.0f + tanhf(inner));
791  out_ptr[d] = gelu_a * b;
792  }
793  }
794 #endif
795 }

References __attribute__().

◆ gelu_backward_exact()

void gelu_backward_exact ( const float *  input,
const float *  d_output,
float *  d_input,
size_t  n 
)

Definition at line 257 of file gelu_kernels.c.

261 {
262  const float sqrt_2_over_pi = 0.7978845608f;
263  const float coeff = 0.044715f;
264 
265 #if defined(__AVX512F__)
266  const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
267  const __m512 coeff_vec = _mm512_set1_ps(coeff);
268  const __m512 coeff3_vec = _mm512_set1_ps(3.0f * coeff);
269  const __m512 half_vec = _mm512_set1_ps(0.5f);
270  const __m512 one_vec = _mm512_set1_ps(1.0f);
271 
272  size_t i = 0;
273  for (; i + 16 <= n; i += 16) {
274  __m512 x = _mm512_loadu_ps(&input[i]);
275  __m512 dy = _mm512_loadu_ps(&d_output[i]);
276 
277  __m512 x2 = _mm512_mul_ps(x, x);
278  __m512 x3 = _mm512_mul_ps(x2, x);
279 
280  // g = sqrt(2/pi) * (x + 0.044715 * x^3)
281  __m512 g = _mm512_fmadd_ps(coeff_vec, x3, x);
282  g = _mm512_mul_ps(sqrt_2_pi_vec, g);
283 
284  __m512 tanh_g = tanh512_fast(g);
285 
286  // g' = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
287  __m512 g_prime = _mm512_fmadd_ps(coeff3_vec, x2, one_vec);
288  g_prime = _mm512_mul_ps(sqrt_2_pi_vec, g_prime);
289 
290  // sech^2(g) = 1 - tanh^2(g)
291  __m512 sech2_g = _mm512_fnmadd_ps(tanh_g, tanh_g, one_vec);
292 
293  // gelu_derivative = 0.5 * (1 + tanh_g) + 0.5 * x * sech2_g * g_prime
294  __m512 term1 = _mm512_mul_ps(half_vec, _mm512_add_ps(one_vec, tanh_g));
295  __m512 term2 = _mm512_mul_ps(half_vec, _mm512_mul_ps(x, _mm512_mul_ps(sech2_g, g_prime)));
296  __m512 gelu_deriv = _mm512_add_ps(term1, term2);
297 
298  __m512 result = _mm512_mul_ps(dy, gelu_deriv);
299  _mm512_storeu_ps(&d_input[i], result);
300  }
301  // Handle remaining elements
302  for (; i < n; ++i) {
303  float x = input[i];
304  float x3 = x * x * x;
305  float g = sqrt_2_over_pi * (x + coeff * x3);
306  float tanh_g = tanhf(g);
307  float x2 = x * x;
308  float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
309  float sech2_g = 1.0f - tanh_g * tanh_g;
310  float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
311  d_input[i] = d_output[i] * gelu_derivative;
312  }
313 
314 #elif defined(__AVX2__)
315  const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
316  const __m256 coeff_vec = _mm256_set1_ps(coeff);
317  const __m256 coeff3_vec = _mm256_set1_ps(3.0f * coeff);
318  const __m256 half_vec = _mm256_set1_ps(0.5f);
319  const __m256 one_vec = _mm256_set1_ps(1.0f);
320 
321  size_t i = 0;
322  for (; i + 8 <= n; i += 8) {
323  __m256 x = _mm256_loadu_ps(&input[i]);
324  __m256 dy = _mm256_loadu_ps(&d_output[i]);
325 
326  __m256 x2 = _mm256_mul_ps(x, x);
327  __m256 x3 = _mm256_mul_ps(x2, x);
328 
329  // g = sqrt(2/pi) * (x + 0.044715 * x^3)
330  __m256 g = _mm256_fmadd_ps(coeff_vec, x3, x);
331  g = _mm256_mul_ps(sqrt_2_pi_vec, g);
332 
333  __m256 tanh_g = tanh256_fast(g);
334 
335  // g' = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
336  __m256 g_prime = _mm256_fmadd_ps(coeff3_vec, x2, one_vec);
337  g_prime = _mm256_mul_ps(sqrt_2_pi_vec, g_prime);
338 
339  // sech^2(g) = 1 - tanh^2(g)
340  __m256 sech2_g = _mm256_fnmadd_ps(tanh_g, tanh_g, one_vec);
341 
342  // gelu_derivative = 0.5 * (1 + tanh_g) + 0.5 * x * sech2_g * g_prime
343  __m256 term1 = _mm256_mul_ps(half_vec, _mm256_add_ps(one_vec, tanh_g));
344  __m256 term2 = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, _mm256_mul_ps(sech2_g, g_prime)));
345  __m256 gelu_deriv = _mm256_add_ps(term1, term2);
346 
347  __m256 result = _mm256_mul_ps(dy, gelu_deriv);
348  _mm256_storeu_ps(&d_input[i], result);
349  }
350  // Handle remaining elements
351  for (; i < n; ++i) {
352  float x = input[i];
353  float x3 = x * x * x;
354  float g = sqrt_2_over_pi * (x + coeff * x3);
355  float tanh_g = tanhf(g);
356  float x2 = x * x;
357  float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
358  float sech2_g = 1.0f - tanh_g * tanh_g;
359  float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
360  d_input[i] = d_output[i] * gelu_derivative;
361  }
362 
363 #elif defined(__AVX__)
364  // AVX1: Vectorize arithmetic, use scalar tanh
365  const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
366  const __m256 coeff_vec = _mm256_set1_ps(coeff);
367  const __m256 coeff3_vec = _mm256_set1_ps(3.0f * coeff);
368  const __m256 half_vec = _mm256_set1_ps(0.5f);
369  const __m256 one_vec = _mm256_set1_ps(1.0f);
370 
371  size_t i = 0;
372  float g_arr[8] __attribute__((aligned(32)));
373  float tanh_arr[8] __attribute__((aligned(32)));
374 
375  for (; i + 8 <= n; i += 8) {
376  __m256 x = _mm256_loadu_ps(&input[i]);
377  __m256 dy = _mm256_loadu_ps(&d_output[i]);
378 
379  __m256 x2 = _mm256_mul_ps(x, x);
380  __m256 x3 = _mm256_mul_ps(x2, x);
381 
382  // g = sqrt(2/pi) * (x + 0.044715 * x^3)
383  __m256 coeff_x3 = _mm256_mul_ps(coeff_vec, x3);
384  __m256 g = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(x, coeff_x3));
385 
386  // Compute tanh scalarly
387  _mm256_store_ps(g_arr, g);
388  for (int j = 0; j < 8; ++j) {
389  tanh_arr[j] = tanhf(g_arr[j]);
390  }
391  __m256 tanh_g = _mm256_load_ps(tanh_arr);
392 
393  // g' = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
394  __m256 coeff3_x2 = _mm256_mul_ps(coeff3_vec, x2);
395  __m256 g_prime = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(one_vec, coeff3_x2));
396 
397  // sech^2(g) = 1 - tanh^2(g)
398  __m256 tanh_g_sq = _mm256_mul_ps(tanh_g, tanh_g);
399  __m256 sech2_g = _mm256_sub_ps(one_vec, tanh_g_sq);
400 
401  // gelu_derivative = 0.5 * (1 + tanh_g) + 0.5 * x * sech2_g * g_prime
402  __m256 term1 = _mm256_mul_ps(half_vec, _mm256_add_ps(one_vec, tanh_g));
403  __m256 term2 = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, _mm256_mul_ps(sech2_g, g_prime)));
404  __m256 gelu_deriv = _mm256_add_ps(term1, term2);
405 
406  __m256 result = _mm256_mul_ps(dy, gelu_deriv);
407  _mm256_storeu_ps(&d_input[i], result);
408  }
409  // Handle remaining elements
410  for (; i < n; ++i) {
411  float x = input[i];
412  float x3 = x * x * x;
413  float g = sqrt_2_over_pi * (x + coeff * x3);
414  float tanh_g = tanhf(g);
415  float x2 = x * x;
416  float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
417  float sech2_g = 1.0f - tanh_g * tanh_g;
418  float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
419  d_input[i] = d_output[i] * gelu_derivative;
420  }
421 
422 #else
423  // Scalar fallback
424  for (size_t i = 0; i < n; ++i) {
425  float x = input[i];
426 
427  float x3 = x * x * x;
428  float g = sqrt_2_over_pi * (x + coeff * x3);
429  float tanh_g = tanhf(g);
430 
431  float x2 = x * x;
432  float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
433 
434  float sech2_g = 1.0f - tanh_g * tanh_g;
435  float gelu_derivative =
436  0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
437 
438  d_input[i] = d_output[i] * gelu_derivative;
439  }
440 #endif
441 }

References __attribute__().

◆ gelu_backward_exact_bf16()

/**
 * GELU backward, exact tanh formula, bf16 tensors.
 *
 * Widens input and upstream gradient to fp32 in the caller-provided scratch
 * buffers (each must hold n floats), runs the scalar exact backward, and
 * narrows the gradient back to bf16.  The scalar path is used deliberately:
 * the fast vector tanh approximation would stack its error on top of BF16
 * rounding loss.  No-op if any scratch pointer is NULL.
 */
void gelu_backward_exact_bf16(const uint16_t *input, const uint16_t *d_output,
                              uint16_t *d_input, size_t n,
                              float *scratch_input, float *scratch_d_output,
                              float *scratch_d_input)
{
    if (!scratch_input || !scratch_d_output || !scratch_d_input) {
        return;
    }

    bf16_tensor_to_float(input, scratch_input, n);
    bf16_tensor_to_float(d_output, scratch_d_output, n);
    gelu_backward_scalar(scratch_input, scratch_d_output, scratch_d_input, n);
    float_tensor_to_bf16(scratch_d_input, d_input, n);
}
void gelu_backward_scalar(const float *input, const float *d_output, float *d_input, size_t n)
Definition: gelu_kernels.c:462

References bf16_tensor_to_float(), float_tensor_to_bf16(), and gelu_backward_scalar().

◆ gelu_backward_fast()

void gelu_backward_fast ( const float *  input,
const float *  d_output,
float *  d_input,
size_t  n 
)

Definition at line 486 of file gelu_kernels.c.

490 {
491  const float beta = 1.702f;
492 
493 #if defined(__AVX512F__)
494  const __m512 beta_vec = _mm512_set1_ps(beta);
495  const __m512 one_vec = _mm512_set1_ps(1.0f);
496  const __m512 neg_beta_vec = _mm512_set1_ps(-beta);
497 
498  size_t i = 0;
499  for (; i + 16 <= n; i += 16) {
500  __m512 x = _mm512_loadu_ps(&input[i]);
501  __m512 dy = _mm512_loadu_ps(&d_output[i]);
502 
503  // s = sigmoid(beta * x) = 1 / (1 + exp(-beta * x))
504  __m512 neg_beta_x = _mm512_mul_ps(neg_beta_vec, x);
505  __m512 exp_neg = exp512_fast(neg_beta_x);
506  __m512 s = _mm512_div_ps(one_vec, _mm512_add_ps(one_vec, exp_neg));
507 
508  // gelu_derivative = s * (1 + x * (1 - s) * beta)
509  __m512 one_minus_s = _mm512_sub_ps(one_vec, s);
510  __m512 inner = _mm512_fmadd_ps(_mm512_mul_ps(x, one_minus_s), beta_vec, one_vec);
511  __m512 gelu_deriv = _mm512_mul_ps(s, inner);
512 
513  __m512 result = _mm512_mul_ps(dy, gelu_deriv);
514  _mm512_storeu_ps(&d_input[i], result);
515  }
516  // Handle remaining elements
517  for (; i < n; ++i) {
518  float x = input[i];
519  float s = 1.0f / (1.0f + expf(-beta * x));
520  float gelu_derivative = s * (1.0f + x * (1.0f - s) * beta);
521  d_input[i] = d_output[i] * gelu_derivative;
522  }
523 
524 #elif defined(__AVX2__)
525  const __m256 beta_vec = _mm256_set1_ps(beta);
526  const __m256 one_vec = _mm256_set1_ps(1.0f);
527  const __m256 neg_beta_vec = _mm256_set1_ps(-beta);
528 
529  size_t i = 0;
530  for (; i + 8 <= n; i += 8) {
531  __m256 x = _mm256_loadu_ps(&input[i]);
532  __m256 dy = _mm256_loadu_ps(&d_output[i]);
533 
534  // s = sigmoid(beta * x) = 1 / (1 + exp(-beta * x))
535  __m256 neg_beta_x = _mm256_mul_ps(neg_beta_vec, x);
536  __m256 exp_neg = exp256_fast(neg_beta_x);
537  __m256 s = _mm256_div_ps(one_vec, _mm256_add_ps(one_vec, exp_neg));
538 
539  // gelu_derivative = s * (1 + x * (1 - s) * beta)
540  __m256 one_minus_s = _mm256_sub_ps(one_vec, s);
541  __m256 inner = _mm256_fmadd_ps(_mm256_mul_ps(x, one_minus_s), beta_vec, one_vec);
542  __m256 gelu_deriv = _mm256_mul_ps(s, inner);
543 
544  __m256 result = _mm256_mul_ps(dy, gelu_deriv);
545  _mm256_storeu_ps(&d_input[i], result);
546  }
547  // Handle remaining elements
548  for (; i < n; ++i) {
549  float x = input[i];
550  float s = 1.0f / (1.0f + expf(-beta * x));
551  float gelu_derivative = s * (1.0f + x * (1.0f - s) * beta);
552  d_input[i] = d_output[i] * gelu_derivative;
553  }
554 
555 #elif defined(__AVX__)
556  // AVX1: Vectorize arithmetic, use scalar exp
557  const __m256 beta_vec = _mm256_set1_ps(beta);
558  const __m256 one_vec = _mm256_set1_ps(1.0f);
559  const __m256 neg_beta_vec = _mm256_set1_ps(-beta);
560 
561  size_t i = 0;
562  float neg_beta_x_arr[8] __attribute__((aligned(32)));
563  float exp_arr[8] __attribute__((aligned(32)));
564 
565  for (; i + 8 <= n; i += 8) {
566  __m256 x = _mm256_loadu_ps(&input[i]);
567  __m256 dy = _mm256_loadu_ps(&d_output[i]);
568 
569  // s = sigmoid(beta * x) = 1 / (1 + exp(-beta * x))
570  __m256 neg_beta_x = _mm256_mul_ps(neg_beta_vec, x);
571 
572  // Compute exp scalarly
573  _mm256_store_ps(neg_beta_x_arr, neg_beta_x);
574  for (int j = 0; j < 8; ++j) {
575  exp_arr[j] = expf(neg_beta_x_arr[j]);
576  }
577  __m256 exp_neg = _mm256_load_ps(exp_arr);
578 
579  __m256 s = _mm256_div_ps(one_vec, _mm256_add_ps(one_vec, exp_neg));
580 
581  // gelu_derivative = s * (1 + x * (1 - s) * beta)
582  __m256 one_minus_s = _mm256_sub_ps(one_vec, s);
583  __m256 x_one_minus_s = _mm256_mul_ps(x, one_minus_s);
584  __m256 x_one_minus_s_beta = _mm256_mul_ps(x_one_minus_s, beta_vec);
585  __m256 inner = _mm256_add_ps(one_vec, x_one_minus_s_beta);
586  __m256 gelu_deriv = _mm256_mul_ps(s, inner);
587 
588  __m256 result = _mm256_mul_ps(dy, gelu_deriv);
589  _mm256_storeu_ps(&d_input[i], result);
590  }
591  // Handle remaining elements
592  for (; i < n; ++i) {
593  float x = input[i];
594  float s = 1.0f / (1.0f + expf(-beta * x));
595  float gelu_derivative = s * (1.0f + x * (1.0f - s) * beta);
596  d_input[i] = d_output[i] * gelu_derivative;
597  }
598 #endif
599 }

References __attribute__().

Referenced by gelu_backward_fast_bf16().

◆ gelu_backward_fast_bf16()

/**
 * GELU backward, fast sigmoid approximation, bf16 tensors.
 *
 * Widens input and upstream gradient to fp32 in the caller-provided scratch
 * buffers (each must hold n floats), runs gelu_backward_fast, and narrows
 * the gradient back to bf16.  No-op if any scratch pointer is NULL.
 */
void gelu_backward_fast_bf16(const uint16_t *input, const uint16_t *d_output,
                             uint16_t *d_input, size_t n,
                             float *scratch_input, float *scratch_d_output,
                             float *scratch_d_input)
{
    if (!scratch_input || !scratch_d_output || !scratch_d_input) {
        return;
    }

    bf16_tensor_to_float(input, scratch_input, n);
    bf16_tensor_to_float(d_output, scratch_d_output, n);
    gelu_backward_fast(scratch_input, scratch_d_output, scratch_d_input, n);
    float_tensor_to_bf16(scratch_d_input, d_input, n);
}
void gelu_backward_fast(const float *input, const float *d_output, float *d_input, size_t n)
Definition: gelu_kernels.c:486

References bf16_tensor_to_float(), float_tensor_to_bf16(), and gelu_backward_fast().

◆ gelu_backward_scalar()

void gelu_backward_scalar ( const float *  input,
const float *  d_output,
float *  d_input,
size_t  n 
)

Definition at line 462 of file gelu_kernels.c.

466 {
467  const float sqrt_2_over_pi = 0.7978845608f;
468  const float coeff = 0.044715f;
469 
470  for (size_t i = 0; i < n; ++i) {
471  float x = input[i];
472  float x3 = x * x * x;
473  float g = sqrt_2_over_pi * (x + coeff * x3);
474  float tanh_g = tanhf(g);
475  float x2 = x * x;
476  float g_prime = sqrt_2_over_pi * (1.0f + 3.0f * coeff * x2);
477  float sech2_g = 1.0f - tanh_g * tanh_g;
478  float gelu_derivative = 0.5f * (1.0f + tanh_g) + 0.5f * x * sech2_g * g_prime;
479  d_input[i] = d_output[i] * gelu_derivative;
480  }
481 }

Referenced by gelu_backward_exact_bf16().

◆ gelu_exact_inplace()

void gelu_exact_inplace ( float *  data,
size_t  n 
)

Definition at line 446 of file gelu_kernels.c.

447 {
448  const float sqrt_2_over_pi = 0.7978845608f;
449  const float coeff = 0.044715f;
450 
451  for (size_t i = 0; i < n; ++i) {
452  float x = data[i];
453  float x3 = x * x * x;
454  float inner = sqrt_2_over_pi * (x + coeff * x3);
455  data[i] = 0.5f * x * (1.0f + tanhf(inner));
456  }
457 }

Referenced by gelu_fast_inplace_bf16(), and mlp_token_parallel_exact().

◆ gelu_fast_inplace()

void gelu_fast_inplace ( float *  data,
size_t  n 
)

GELU activation forward (fast approximation, in-place)

Test:

test_gelu.py::TestGELUForward::test_gelu_fast_inplace

test_gelu.py::TestGELUForward::test_gelu_vs_exact

test_parity.py::test_gelu_parity

Fast GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) In-place on contiguous buffer.

After changes: make test && make llamacpp-parity-full

Definition at line 132 of file gelu_kernels.c.

133 {
134  const float sqrt_2_over_pi = 0.7978845608f;
135  const float coeff = 0.044715f;
136 
137 #if defined(__AVX512F__)
138  const __m512 sqrt_2_pi_vec = _mm512_set1_ps(sqrt_2_over_pi);
139  const __m512 coeff_vec = _mm512_set1_ps(coeff);
140  const __m512 half_vec = _mm512_set1_ps(0.5f);
141  const __m512 one_vec = _mm512_set1_ps(1.0f);
142 
143  size_t i = 0;
144  for (; i + 16 <= n; i += 16) {
145  __m512 x = _mm512_loadu_ps(&data[i]);
146  __m512 x2 = _mm512_mul_ps(x, x);
147  __m512 x3 = _mm512_mul_ps(x2, x);
148 
149  // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
150  __m512 inner = _mm512_fmadd_ps(coeff_vec, x3, x);
151  inner = _mm512_mul_ps(sqrt_2_pi_vec, inner);
152 
153  // result = 0.5 * x * (1 + tanh(inner))
154  __m512 tanh_val = tanh512_fast(inner);
155  __m512 one_plus_tanh = _mm512_add_ps(one_vec, tanh_val);
156  __m512 result = _mm512_mul_ps(half_vec, _mm512_mul_ps(x, one_plus_tanh));
157 
158  _mm512_storeu_ps(&data[i], result);
159  }
160  // Handle remaining elements
161  for (; i < n; ++i) {
162  float x = data[i];
163  float x3 = x * x * x;
164  float inner = sqrt_2_over_pi * (x + coeff * x3);
165  data[i] = 0.5f * x * (1.0f + tanhf(inner));
166  }
167 
168 #elif defined(__AVX2__)
169  const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
170  const __m256 coeff_vec = _mm256_set1_ps(coeff);
171  const __m256 half_vec = _mm256_set1_ps(0.5f);
172  const __m256 one_vec = _mm256_set1_ps(1.0f);
173 
174  size_t i = 0;
175  for (; i + 8 <= n; i += 8) {
176  __m256 x = _mm256_loadu_ps(&data[i]);
177  __m256 x2 = _mm256_mul_ps(x, x);
178  __m256 x3 = _mm256_mul_ps(x2, x);
179 
180  // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
181  __m256 inner = _mm256_fmadd_ps(coeff_vec, x3, x);
182  inner = _mm256_mul_ps(sqrt_2_pi_vec, inner);
183 
184  // result = 0.5 * x * (1 + tanh(inner))
185  __m256 tanh_val = tanh256_fast(inner);
186  __m256 one_plus_tanh = _mm256_add_ps(one_vec, tanh_val);
187  __m256 result = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, one_plus_tanh));
188 
189  _mm256_storeu_ps(&data[i], result);
190  }
191  // Handle remaining elements
192  for (; i < n; ++i) {
193  float x = data[i];
194  float x3 = x * x * x;
195  float inner = sqrt_2_over_pi * (x + coeff * x3);
196  data[i] = 0.5f * x * (1.0f + tanhf(inner));
197  }
198 
199 #elif defined(__AVX__)
200  // AVX1: Vectorize arithmetic, use scalar tanh
201  const __m256 sqrt_2_pi_vec = _mm256_set1_ps(sqrt_2_over_pi);
202  const __m256 coeff_vec = _mm256_set1_ps(coeff);
203  const __m256 half_vec = _mm256_set1_ps(0.5f);
204  const __m256 one_vec = _mm256_set1_ps(1.0f);
205 
206  size_t i = 0;
207  float inner_arr[8] __attribute__((aligned(32)));
208  float tanh_arr[8] __attribute__((aligned(32)));
209 
210  for (; i + 8 <= n; i += 8) {
211  __m256 x = _mm256_loadu_ps(&data[i]);
212  __m256 x2 = _mm256_mul_ps(x, x);
213  __m256 x3 = _mm256_mul_ps(x2, x);
214 
215  // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
216  __m256 coeff_x3 = _mm256_mul_ps(coeff_vec, x3);
217  __m256 inner = _mm256_mul_ps(sqrt_2_pi_vec, _mm256_add_ps(x, coeff_x3));
218 
219  // Compute tanh scalarly
220  _mm256_store_ps(inner_arr, inner);
221  for (int j = 0; j < 8; ++j) {
222  tanh_arr[j] = tanhf(inner_arr[j]);
223  }
224  __m256 tanh_val = _mm256_load_ps(tanh_arr);
225 
226  // result = 0.5 * x * (1 + tanh(inner))
227  __m256 one_plus_tanh = _mm256_add_ps(one_vec, tanh_val);
228  __m256 result = _mm256_mul_ps(half_vec, _mm256_mul_ps(x, one_plus_tanh));
229 
230  _mm256_storeu_ps(&data[i], result);
231  }
232  // Handle remaining elements
233  for (; i < n; ++i) {
234  float x = data[i];
235  float x3 = x * x * x;
236  float inner = sqrt_2_over_pi * (x + coeff * x3);
237  data[i] = 0.5f * x * (1.0f + tanhf(inner));
238  }
239 
240 #else
241  // Scalar fallback
242  for (size_t i = 0; i < n; ++i) {
243  float x = data[i];
244  float x3 = x * x * x;
245  float inner = sqrt_2_over_pi * (x + coeff * x3);
246  data[i] = 0.5f * x * (1.0f + tanhf(inner));
247  }
248 #endif
249 }

References __attribute__().

Referenced by mlp_token_parallel().

◆ gelu_fast_inplace_bf16()

void gelu_fast_inplace_bf16 ( uint16_t *  data,
size_t  n,
float *  scratch 
)

Definition at line 31 of file gelu_kernels_bf16.c.

32 {
33  if (!scratch) return;
34 
35  bf16_tensor_to_float(data, scratch, n);
36  // Use exact version to avoid fast tanh approximation error accumulating
37  // with BF16 precision loss. Conversion overhead dominates anyway.
38  gelu_exact_inplace(scratch, n);
39  float_tensor_to_bf16(scratch, data, n);
40 }
void gelu_exact_inplace(float *data, size_t n)
Definition: gelu_kernels.c:446

References bf16_tensor_to_float(), float_tensor_to_bf16(), and gelu_exact_inplace().

◆ gemm_avx512_parallel()

void gemm_avx512_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 149 of file gemm_kernels.c.

154 {
155  if (ck_strict_parity_enabled()) {
156  gemm_naive_serial_double(A, B, bias, C, M, N, K);
157  return;
158  }
159 #if defined(__AVX512F__)
160 #pragma omp parallel for
161  for (int i = 0; i < M; i++) {
162  for (int j = 0; j < N; j++) {
163  __m512 sum_vec = _mm512_setzero_ps();
164  int k;
165  for (k = 0; k <= K - 16; k += 16) {
166  __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
167  __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
168  sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
169  }
170  float sum = _mm512_reduce_add_ps(sum_vec);
171  for (; k < K; k++) {
172  sum += A[i * K + k] * B[j * K + k];
173  }
174  float bias_val = bias ? bias[j] : 0.0f;
175  C[i * N + j] = sum + bias_val;
176  }
177  }
178 #elif defined(__AVX__)
179  // AVX1 path: 256-bit vectors, no FMA (use mul + add)
180 #pragma omp parallel for
181  for (int i = 0; i < M; i++) {
182  for (int j = 0; j < N; j++) {
183  __m256 sum_vec = _mm256_setzero_ps();
184  int k;
185  for (k = 0; k <= K - 8; k += 8) {
186  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
187  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
188  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
189  sum_vec = _mm256_add_ps(sum_vec, prod);
190  }
191  float sum = hsum256_ps(sum_vec);
192  for (; k < K; k++) {
193  sum += A[i * K + k] * B[j * K + k];
194  }
195  float bias_val = bias ? bias[j] : 0.0f;
196  C[i * N + j] = sum + bias_val;
197  }
198  }
199 #else
200  gemm_naive_parallel(A, B, bias, C, M, N, K);
201 #endif
202 }
int ck_strict_parity_enabled(void)
static void gemm_naive_serial_double(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:107
void gemm_naive_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:125
#define C(color)
Definition: show_config.c:39

References C, ck_strict_parity_enabled(), gemm_naive_parallel(), and gemm_naive_serial_double().

◆ gemm_bias_gelu_fused()

void gemm_bias_gelu_fused ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 131 of file gemm_fused_kernels.c.

136 {
137 #if defined(__AVX__)
138 #pragma omp parallel for
139  for (int i = 0; i < M; i++) {
140  for (int j = 0; j < N; j++) {
141  __m256 sum_vec = _mm256_setzero_ps();
142  int k;
143  for (k = 0; k <= K - 8; k += 8) {
144  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
145  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
146  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
147  sum_vec = _mm256_add_ps(sum_vec, prod);
148  }
149  float sum = hsum256_ps_fused(sum_vec);
150  for (; k < K; k++) {
151  sum += A[i * K + k] * B[j * K + k];
152  }
153  sum += bias[j];
154  C[i * N + j] = fast_gelu_scalar(sum);
155  }
156  }
157 #else
158 #pragma omp parallel for
159  for (int i = 0; i < M; i++) {
160  for (int j = 0; j < N; j++) {
161  float sum = 0.0f;
162  for (int k = 0; k < K; k++) {
163  sum += A[i * K + k] * B[j * K + k];
164  }
165  sum += bias[j];
166  C[i * N + j] = fast_gelu_scalar(sum);
167  }
168  }
169 #endif
170 }
static float fast_gelu_scalar(float x)
static float hsum256_ps_fused(__m256 v)

References C, fast_gelu_scalar(), and hsum256_ps_fused().

◆ gemm_bias_relu_fused()

void gemm_bias_relu_fused ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 84 of file gemm_fused_kernels.c.

89 {
90 #if defined(__AVX__)
91 #pragma omp parallel for
92  for (int i = 0; i < M; i++) {
93  for (int j = 0; j < N; j++) {
94  __m256 sum_vec = _mm256_setzero_ps();
95  int k;
96  for (k = 0; k <= K - 8; k += 8) {
97  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
98  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
99  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
100  sum_vec = _mm256_add_ps(sum_vec, prod);
101  }
102  float sum = hsum256_ps_fused(sum_vec);
103  for (; k < K; k++) {
104  sum += A[i * K + k] * B[j * K + k];
105  }
106  // Fused: add bias and ReLU while still in register
107  sum += bias[j];
108  C[i * N + j] = sum > 0.0f ? sum : 0.0f;
109  }
110  }
111 #else
112 #pragma omp parallel for
113  for (int i = 0; i < M; i++) {
114  for (int j = 0; j < N; j++) {
115  float sum = 0.0f;
116  for (int k = 0; k < K; k++) {
117  sum += A[i * K + k] * B[j * K + k];
118  }
119  sum += bias[j];
120  C[i * N + j] = sum > 0.0f ? sum : 0.0f;
121  }
122  }
123 #endif
124 }

References C, and hsum256_ps_fused().

◆ gemm_bias_silu_fused()

void gemm_bias_silu_fused ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 177 of file gemm_fused_kernels.c.

182 {
183 #if defined(__AVX__)
184 #pragma omp parallel for
185  for (int i = 0; i < M; i++) {
186  for (int j = 0; j < N; j++) {
187  __m256 sum_vec = _mm256_setzero_ps();
188  int k;
189  for (k = 0; k <= K - 8; k += 8) {
190  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
191  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
192  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
193  sum_vec = _mm256_add_ps(sum_vec, prod);
194  }
195  float sum = hsum256_ps_fused(sum_vec);
196  for (; k < K; k++) {
197  sum += A[i * K + k] * B[j * K + k];
198  }
199  sum += bias[j];
200  // SiLU: x * sigmoid(x)
201  float sig = 1.0f / (1.0f + expf(-sum));
202  C[i * N + j] = sum * sig;
203  }
204  }
205 #else
206 #pragma omp parallel for
207  for (int i = 0; i < M; i++) {
208  for (int j = 0; j < N; j++) {
209  float sum = 0.0f;
210  for (int k = 0; k < K; k++) {
211  sum += A[i * K + k] * B[j * K + k];
212  }
213  sum += bias[j];
214  float sig = 1.0f / (1.0f + expf(-sum));
215  C[i * N + j] = sum * sig;
216  }
217  }
218 #endif
219 }

References C, and hsum256_ps_fused().

◆ gemm_blocked_serial()

void gemm_blocked_serial ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 661 of file gemm_kernels.c.

666 {
667  // Ensure threads are initialized (auto-detects on first call)
668  (void)ck_get_num_threads();
669 
670  if (ck_strict_parity_enabled()) {
671  gemm_naive_serial_double(A, B, bias, C, M, N, K);
672  return;
673  }
674 
675  // Decode-time matvec (M=1) is extremely common and benefits from parallelism over N.
676  // Lower threshold to parallelize more ops; OpenMP overhead is ~1-2μs per barrier.
677  // For N*K >= 64K elements, parallel is worthwhile.
678  if (M == 1 && (size_t)N * (size_t)K >= 65536) {
679  gemm_nt_matvec_parallel(A, B, bias, C, N, K);
680  return;
681  }
682 
683  /*
684  * Use gemm_microkernel for large matrices - it uses MKL/oneDNN when available,
685  * which is substantially faster than our hand-written SIMD kernels.
686  * B is stored as [N x K] (transposed), so we pass B_transposed=1.
687  * Note: Use threshold of 32 to avoid numerical precision issues with small matrices.
688  */
689  if (M >= 32 && N >= 32 && K >= 32) {
690  gemm_microkernel(A, B, C, M, N, K, 1); // B_transposed=1
691  ck_gemm_add_bias(C, bias, M, N);
692  return;
693  }
694 #if defined(__AVX512F__)
695  const int block_size = 64;
696 #elif defined(__AVX__)
697  const int block_size = 32;
698 #else
699  const int block_size = 32;
700 #endif
701  for (int i = 0; i < M; i++) {
702  for (int j = 0; j < N; j++) {
703  C[i * N + j] = bias ? bias[j] : 0.0f;
704  }
705  }
706  for (int ii = 0; ii < M; ii += block_size) {
707  for (int jj = 0; jj < N; jj += block_size) {
708  for (int kk = 0; kk < K; kk += block_size) {
709  int i_end = ck_min(ii + block_size, M);
710  int j_end = ck_min(jj + block_size, N);
711  int k_end = ck_min(kk + block_size, K);
712 
713  for (int i = ii; i < i_end; i++) {
714  for (int j = jj; j < j_end; j++) {
715 #if defined(__AVX512F__)
716  __m512 sum_vec = _mm512_setzero_ps();
717  int k;
718  for (k = kk; k <= k_end - 16; k += 16) {
719  __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
720  __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
721  sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
722  }
723  float partial_sum = _mm512_reduce_add_ps(sum_vec);
724  for (; k < k_end; k++) {
725  partial_sum += A[i * K + k] * B[j * K + k];
726  }
727 #elif defined(__AVX__)
728  __m256 sum_vec = _mm256_setzero_ps();
729  int k;
730  for (k = kk; k <= k_end - 8; k += 8) {
731  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
732  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
733  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
734  sum_vec = _mm256_add_ps(sum_vec, prod);
735  }
736  float partial_sum = hsum256_ps(sum_vec);
737  for (; k < k_end; k++) {
738  partial_sum += A[i * K + k] * B[j * K + k];
739  }
740 #else
741  float partial_sum = 0.0f;
742  for (int k = kk; k < k_end; k++) {
743  partial_sum += A[i * K + k] * B[j * K + k];
744  }
745 #endif
746  C[i * N + j] += partial_sum;
747  }
748  }
749  }
750  }
751  }
752 }
void gemm_microkernel(const float *A, const float *B, float *C, int M, int N, int K, int B_transposed)
int ck_get_num_threads(void)
static int ck_min(int a, int b)
Definition: gemm_kernels.c:26
static void gemm_nt_matvec_parallel(const float *A, const float *B, const float *bias, float *C, int N, int K)
Definition: gemm_kernels.c:61
static void ck_gemm_add_bias(float *C, const float *bias, int M, int N)
Definition: gemm_kernels.c:28

References C, ck_gemm_add_bias(), ck_get_num_threads(), ck_min(), ck_strict_parity_enabled(), gemm_microkernel(), gemm_naive_serial_double(), and gemm_nt_matvec_parallel().

Referenced by ck_attention_project_head_major(), ck_gemm_nt_quant(), ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_fused_token(), ck_qkv_project_head_major(), ck_qkv_project_head_major_token(), mlp_token_parallel(), and mlp_token_parallel_exact().

◆ gemm_blocked_serial_bf16()

void gemm_blocked_serial_bf16 ( const uint16_t *  A,
const uint16_t *  B,
const uint16_t *  bias,
uint16_t *  C,
int  M,
int  N,
int  K 
)

Definition at line 272 of file gemm_kernels_bf16.c.

277 {
278  if (!A || !B || !C || M <= 0 || N <= 0 || K <= 0) {
279  return;
280  }
281 
282 #if HAVE_NATIVE_BF16
283  /* Native BF16 instructions available (Ice Lake / Sapphire Rapids+) */
284  gemm_bf16_native(A, B, bias, C, M, N, K);
285 #elif defined(__AVX512F__)
286  /* Use AVX-512F with software BF16 conversion */
287  if (M * N > 4096) {
288  gemm_bf16_blocked_avx512(A, B, bias, C, M, N, K);
289  } else {
290  gemm_bf16_avx512(A, B, bias, C, M, N, K);
291  }
292 #else
293  /* Scalar fallback */
294  gemm_bf16_scalar(A, B, bias, C, M, N, K);
295 #endif
296 }

References C.

◆ gemm_fine_grained_parallel()

void gemm_fine_grained_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 205 of file gemm_kernels.c.

210 {
211  if (ck_strict_parity_enabled()) {
212  gemm_naive_serial_double(A, B, bias, C, M, N, K);
213  return;
214  }
215 #if defined(__AVX512F__)
216  const int block_size = 64;
217 #pragma omp parallel for
218  for (int i = 0; i < M; i++) {
219  for (int j = 0; j < N; j++) {
220  C[i * N + j] = bias ? bias[j] : 0.0f;
221  }
222  }
223 #pragma omp parallel for collapse(3)
224  for (int ii = 0; ii < M; ii += block_size) {
225  for (int jj = 0; jj < N; jj += block_size) {
226  for (int kk = 0; kk < K; kk += block_size) {
227  int i_end = ck_min(ii + block_size, M);
228  int j_end = ck_min(jj + block_size, N);
229  int k_end = ck_min(kk + block_size, K);
230 
231  for (int i = ii; i < i_end; i++) {
232  for (int j = jj; j < j_end; j++) {
233  __m512 sum_vec = _mm512_setzero_ps();
234  int k;
235  for (k = kk; k <= k_end - 16; k += 16) {
236  __m512 a_vec = _mm512_loadu_ps(&A[i * K + k]);
237  __m512 b_vec = _mm512_loadu_ps(&B[j * K + k]);
238  sum_vec = _mm512_fmadd_ps(a_vec, b_vec, sum_vec);
239  }
240  float partial_sum = _mm512_reduce_add_ps(sum_vec);
241  for (; k < k_end; k++) {
242  partial_sum += A[i * K + k] * B[j * K + k];
243  }
244 #pragma omp atomic
245  C[i * N + j] += partial_sum;
246  }
247  }
248  }
249  }
250  }
251 #elif defined(__AVX__)
252  // AVX1 cache-blocked version
253  const int block_size = 32; // Smaller block for L1 cache
254 #pragma omp parallel for
255  for (int i = 0; i < M; i++) {
256  for (int j = 0; j < N; j++) {
257  C[i * N + j] = bias ? bias[j] : 0.0f;
258  }
259  }
260 #pragma omp parallel for collapse(3)
261  for (int ii = 0; ii < M; ii += block_size) {
262  for (int jj = 0; jj < N; jj += block_size) {
263  for (int kk = 0; kk < K; kk += block_size) {
264  int i_end = ck_min(ii + block_size, M);
265  int j_end = ck_min(jj + block_size, N);
266  int k_end = ck_min(kk + block_size, K);
267 
268  for (int i = ii; i < i_end; i++) {
269  for (int j = jj; j < j_end; j++) {
270  __m256 sum_vec = _mm256_setzero_ps();
271  int k;
272  for (k = kk; k <= k_end - 8; k += 8) {
273  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
274  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
275  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
276  sum_vec = _mm256_add_ps(sum_vec, prod);
277  }
278  float partial_sum = hsum256_ps(sum_vec);
279  for (; k < k_end; k++) {
280  partial_sum += A[i * K + k] * B[j * K + k];
281  }
282 #pragma omp atomic
283  C[i * N + j] += partial_sum;
284  }
285  }
286  }
287  }
288  }
289 #else
290  gemm_naive_parallel(A, B, bias, C, M, N, K);
291 #endif
292 }

References C, ck_min(), ck_strict_parity_enabled(), gemm_naive_parallel(), and gemm_naive_serial_double().

◆ gemm_microkernel()

void gemm_microkernel ( const float *  A,
const float *  B,
float *  C,
int  M,
int  N,
int  K,
int  B_transposed 
)

Definition at line 1134 of file gemm_microkernel.c.

1141 {
1142  if (B_transposed) {
1143  gemm_microkernel_blocked_bt(A, B, C, M, N, K);
1144  } else {
1145  // Use packed version for large matrices
1146  if (M >= PACK_THRESHOLD && N >= PACK_THRESHOLD && K >= PACK_THRESHOLD) {
1147  gemm_microkernel_packed(A, B, C, M, N, K);
1148  } else {
1149  gemm_microkernel_blocked(A, B, C, M, N, K);
1150  }
1151  }
1152 }
#define PACK_THRESHOLD
void gemm_microkernel_blocked(const float *A, const float *B, float *C, int M, int N, int K)
void gemm_microkernel_blocked_bt(const float *A, const float *B, float *C, int M, int N, int K)
void gemm_microkernel_packed(const float *A, const float *B, float *C, int M, int N, int K)

References C, gemm_microkernel_blocked(), gemm_microkernel_blocked_bt(), gemm_microkernel_packed(), and PACK_THRESHOLD.

Referenced by gemm_blocked_serial().

◆ gemm_microkernel_blocked()

void gemm_microkernel_blocked ( const float *  A,
const float *  B,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 934 of file gemm_microkernel.c.

940 {
941  const int mr = MR;
942  const int nr = NR;
943 
944  // Use sequential version for small matrices to avoid OpenMP overhead
945  // Threshold tuned for typical 4-8 core systems
946  if ((size_t)M * N * K <= 512ULL * 512 * 512) {
947  gemm_microkernel_sequential(A, B, C, M, N, K);
948  return;
949  }
950 
 951  // Initialize thread count to physical cores (once)
 952  gemm_init_threads();
 953 
954  // Zero output first
955  #pragma omp parallel for schedule(static)
956  for (int i = 0; i < M; i++) {
957  memset(&C[i * N], 0, N * sizeof(float));
958  }
959 
960  // Block over K (outermost - for accumulation across all threads)
961  for (int k0 = 0; k0 < K; k0 += KC) {
962  int kb = (k0 + KC <= K) ? KC : (K - k0);
963  int first_k = (k0 == 0);
964 
965  // Parallelize over M rows - each thread gets a chunk of M
966  // This gives better cache locality than tile-level parallelism
967  #pragma omp parallel for schedule(static)
968  for (int m0 = 0; m0 < M; m0 += mr) {
969  int mr_actual = (m0 + mr <= M) ? mr : (M - m0);
970 
971  // Each thread processes all N tiles for its M rows
972  for (int n0 = 0; n0 < N; n0 += nr) {
973  int nr_actual = (n0 + nr <= N) ? nr : (N - n0);
974 
975  const float *A_tile = &A[m0 * K + k0];
976  const float *B_tile = &B[k0 * N + n0];
977  float *C_tile = &C[m0 * N + n0];
978 
979  if (mr_actual == mr && nr_actual == nr) {
980 #if defined(__AVX512F__)
981  gemm_microkernel_6x32_avx512(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
982 #elif defined(__FMA__)
983  gemm_microkernel_6x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
984 #elif defined(__AVX__)
985  gemm_microkernel_4x16_avx(kb, A_tile, K, B_tile, N, C_tile, N, first_k);
986 #else
987  gemm_microkernel_edge(mr_actual, nr_actual, kb, A_tile, K, B_tile, N, C_tile, N, first_k);
988 #endif
989  } else {
990  gemm_microkernel_edge(mr_actual, nr_actual, kb, A_tile, K, B_tile, N, C_tile, N, first_k);
991  }
992  }
993  }
994  }
995 }
static void gemm_microkernel_edge(int m, int n, int K, const float *A, int lda, const float *B, int ldb, float *C, int ldc, int first_k)
static void gemm_microkernel_sequential(const float *A, const float *B, float *C, int M, int N, int K)
static void gemm_init_threads(void)
#define KC
#define MR
#define NR

References C, gemm_init_threads(), gemm_microkernel_edge(), gemm_microkernel_sequential(), KC, MR, and NR.

Referenced by gemm_microkernel(), and gemm_microkernel_packed().

◆ gemm_microkernel_blocked_bt()

void gemm_microkernel_blocked_bt ( const float *  A,
const float *  B,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 1058 of file gemm_microkernel.c.

1064 {
1065  // Zero output first
1066  #pragma omp parallel for schedule(static)
1067  for (int i = 0; i < M; i++) {
1068  memset(&C[i * N], 0, N * sizeof(float));
1069  }
1070 
1071  const int mr = MR;
1072  const int nr = NR;
1073 
1074  #pragma omp parallel for schedule(dynamic) collapse(2)
1075  for (int m0 = 0; m0 < M; m0 += MC) {
1076  for (int n0 = 0; n0 < N; n0 += NC) {
1077  int mb = (m0 + MC <= M) ? MC : (M - m0);
1078  int nb = (n0 + NC <= N) ? NC : (N - n0);
1079 
1080  for (int k0 = 0; k0 < K; k0 += KC) {
1081  int kb = (k0 + KC <= K) ? KC : (K - k0);
1082  int first_k = (k0 == 0);
1083 
1084  for (int m1 = 0; m1 < mb; m1 += mr) {
1085  int mr_actual = (m1 + mr <= mb) ? mr : (mb - m1);
1086 
1087  for (int n1 = 0; n1 < nb; n1 += nr) {
1088  int nr_actual = (n1 + nr <= nb) ? nr : (nb - n1);
1089 
1090  const float *A_tile = &A[(m0 + m1) * K + k0];
1091  const float *B_tile = &B[(n0 + n1) * K + k0];
1092  float *C_tile = &C[(m0 + m1) * N + (n0 + n1)];
1093 
1094  if (mr_actual == mr && nr_actual == nr) {
1095 #if defined(__AVX512F__)
1096  gemm_microkernel_6x32_bt_avx512(kb, A_tile, K, B_tile, K, C_tile, N, first_k);
1097 #else
1098  // Scalar fallback for B-transposed
1099  for (int i = 0; i < mr; i++) {
1100  for (int j = 0; j < nr; j++) {
1101  float sum = first_k ? 0.0f : C_tile[i * N + j];
1102  for (int kk = 0; kk < kb; kk++) {
1103  sum += A_tile[i * K + kk] * B_tile[j * K + kk];
1104  }
1105  C_tile[i * N + j] = sum;
1106  }
1107  }
1108 #endif
1109  } else {
1110  // Edge case
1111  for (int i = 0; i < mr_actual; i++) {
1112  for (int j = 0; j < nr_actual; j++) {
1113  float sum = first_k ? 0.0f : C_tile[i * N + j];
1114  for (int kk = 0; kk < kb; kk++) {
1115  sum += A_tile[i * K + kk] * B_tile[j * K + kk];
1116  }
1117  C_tile[i * N + j] = sum;
1118  }
1119  }
1120  }
1121  }
1122  }
1123  }
1124  }
1125  }
1126 }
#define NC
#define MC

References C, KC, MC, MR, NC, and NR.

Referenced by gemm_microkernel().

◆ gemm_microkernel_packed()

void gemm_microkernel_packed ( const float *  A,
const float *  B,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 840 of file gemm_microkernel.c.

846 {
847  // Use tile-parallel blocked version - scales better on many-core systems
848  gemm_microkernel_blocked(A, B, C, M, N, K);
849 }

References C, and gemm_microkernel_blocked().

Referenced by gemm_microkernel().

◆ gemm_naive_parallel()

void gemm_naive_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 125 of file gemm_kernels.c.

130 {
131  if (ck_strict_parity_enabled()) {
132  gemm_naive_serial_double(A, B, bias, C, M, N, K);
133  return;
134  }
135 #pragma omp parallel for
136  for (int i = 0; i < M; i++) {
137  for (int j = 0; j < N; j++) {
138  float sum = 0.0f;
139  for (int k = 0; k < K; k++) {
140  sum += A[i * K + k] * B[j * K + k];
141  }
142  float bias_val = bias ? bias[j] : 0.0f;
143  C[i * N + j] = sum + bias_val;
144  }
145  }
146 }

References C, ck_strict_parity_enabled(), and gemm_naive_serial_double().

Referenced by ck_attention_project_head_major_ref(), ck_mlp_swiglu_forward_ref(), ck_qkv_project_head_major_ref(), gemm_avx512_parallel(), and gemm_fine_grained_parallel().

◆ gemm_nn_avx512()

void gemm_nn_avx512 ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 339 of file gemm_kernels.c.

344 {
345  if (ck_strict_parity_enabled()) {
346  gemm_nn_serial_double(A, B, bias, C, M, N, K);
347  return;
348  }
349 #if defined(__AVX512F__)
350  // For gemm_nn, we can't vectorize over K easily since B[k,j] has stride N.
351  // Instead, vectorize over N (output columns) when N >= 16.
352 #pragma omp parallel for
353  for (int i = 0; i < M; i++) {
354  int j = 0;
355  // Process 16 output columns at a time
356  for (; j <= N - 16; j += 16) {
357  __m512 sum_vec = bias ? _mm512_loadu_ps(&bias[j]) : _mm512_setzero_ps();
358  for (int k = 0; k < K; k++) {
359  __m512 a_broadcast = _mm512_set1_ps(A[i * K + k]);
360  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
361  sum_vec = _mm512_fmadd_ps(a_broadcast, b_vec, sum_vec);
362  }
363  _mm512_storeu_ps(&C[i * N + j], sum_vec);
364  }
365  // Handle remaining columns
366  for (; j < N; j++) {
367  float sum = bias ? bias[j] : 0.0f;
368  for (int k = 0; k < K; k++) {
369  sum += A[i * K + k] * B[k * N + j];
370  }
371  C[i * N + j] = sum;
372  }
373  }
374 #elif defined(__AVX__)
375  // AVX1: vectorize over N (8 columns at a time)
376 #pragma omp parallel for
377  for (int i = 0; i < M; i++) {
378  int j = 0;
379  for (; j <= N - 8; j += 8) {
380  __m256 sum_vec = bias ? _mm256_loadu_ps(&bias[j]) : _mm256_setzero_ps();
381  for (int k = 0; k < K; k++) {
382  __m256 a_broadcast = _mm256_set1_ps(A[i * K + k]);
383  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
384  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
385  sum_vec = _mm256_add_ps(sum_vec, prod);
386  }
387  _mm256_storeu_ps(&C[i * N + j], sum_vec);
388  }
389  for (; j < N; j++) {
390  float sum = bias ? bias[j] : 0.0f;
391  for (int k = 0; k < K; k++) {
392  sum += A[i * K + k] * B[k * N + j];
393  }
394  C[i * N + j] = sum;
395  }
396  }
397 #else
398  gemm_nn_parallel(A, B, bias, C, M, N, K);
399 #endif
400 }
static void gemm_nn_serial_double(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:300
void gemm_nn_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:317

References C, ck_strict_parity_enabled(), gemm_nn_parallel(), and gemm_nn_serial_double().

Referenced by fc1_backward_kernel(), and fc2_backward_kernel().

◆ gemm_nn_blocked()

void gemm_nn_blocked ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 402 of file gemm_kernels.c.

407 {
408  if (ck_strict_parity_enabled()) {
409  gemm_nn_serial_double(A, B, bias, C, M, N, K);
410  return;
411  }
412 #if defined(__AVX512F__)
413  const int block_size = 64;
414 #elif defined(__AVX__)
415  const int block_size = 32;
416 #else
417  const int block_size = 32;
418 #endif
419  // Initialize C with bias (parallelized)
420 #pragma omp parallel for
421  for (int i = 0; i < M; i++) {
422  for (int j = 0; j < N; j++) {
423  C[i * N + j] = bias ? bias[j] : 0.0f;
424  }
425  }
426  // Blocked multiply-accumulate (parallelized over M blocks)
427 #pragma omp parallel for
428  for (int ii = 0; ii < M; ii += block_size) {
429  for (int kk = 0; kk < K; kk += block_size) {
430  for (int jj = 0; jj < N; jj += block_size) {
431  int i_end = ck_min(ii + block_size, M);
432  int k_end = ck_min(kk + block_size, K);
433  int j_end = ck_min(jj + block_size, N);
434 
435  for (int i = ii; i < i_end; i++) {
436  for (int k = kk; k < k_end; k++) {
437  float a_val = A[i * K + k];
438 #if defined(__AVX512F__)
439  __m512 a_broadcast = _mm512_set1_ps(a_val);
440  int j;
441  for (j = jj; j <= j_end - 16; j += 16) {
442  __m512 b_vec = _mm512_loadu_ps(&B[k * N + j]);
443  __m512 c_vec = _mm512_loadu_ps(&C[i * N + j]);
444  c_vec = _mm512_fmadd_ps(a_broadcast, b_vec, c_vec);
445  _mm512_storeu_ps(&C[i * N + j], c_vec);
446  }
447  for (; j < j_end; j++) {
448  C[i * N + j] += a_val * B[k * N + j];
449  }
450 #elif defined(__AVX__)
451  __m256 a_broadcast = _mm256_set1_ps(a_val);
452  int j;
453  for (j = jj; j <= j_end - 8; j += 8) {
454  __m256 b_vec = _mm256_loadu_ps(&B[k * N + j]);
455  __m256 c_vec = _mm256_loadu_ps(&C[i * N + j]);
456  __m256 prod = _mm256_mul_ps(a_broadcast, b_vec);
457  c_vec = _mm256_add_ps(c_vec, prod);
458  _mm256_storeu_ps(&C[i * N + j], c_vec);
459  }
460  for (; j < j_end; j++) {
461  C[i * N + j] += a_val * B[k * N + j];
462  }
463 #else
464  for (int j = jj; j < j_end; j++) {
465  C[i * N + j] += a_val * B[k * N + j];
466  }
467 #endif
468  }
469  }
470  }
471  }
472  }
473 }

References C, ck_min(), ck_strict_parity_enabled(), and gemm_nn_serial_double().

◆ gemm_nn_parallel()

void gemm_nn_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 317 of file gemm_kernels.c.

322 {
323  if (ck_strict_parity_enabled()) {
324  gemm_nn_serial_double(A, B, bias, C, M, N, K);
325  return;
326  }
327 #pragma omp parallel for
328  for (int i = 0; i < M; i++) {
329  for (int j = 0; j < N; j++) {
330  float sum = bias ? bias[j] : 0.0f;
331  for (int k = 0; k < K; k++) {
332  sum += A[i * K + k] * B[k * N + j];
333  }
334  C[i * N + j] = sum;
335  }
336  }
337 }

References C, ck_strict_parity_enabled(), and gemm_nn_serial_double().

Referenced by gemm_nn_avx512().

◆ gemm_nt_q4_0()

void gemm_nt_q4_0 ( const float *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.

Parameters
AInput matrix [M x K], row-major FP32
BWeight matrix in Q4_0 format, [N x K] stored row-major
biasOptional bias [N], NULL if not used
COutput [M x N], row-major FP32
MBatch size (number of tokens)
NOutput dimension (number of rows in B)
KInput dimension

Definition at line 176 of file gemm_kernels_q4_0.c.

181 {
182  const block_q4_0 *blocks = (const block_q4_0 *)B;
183  const int blocks_per_row = K / QK4_0;
184 
185  for (int m = 0; m < M; m++) {
186  const float *a_row = &A[m * K];
187 
188  for (int n = 0; n < N; n++) {
189  float sum = 0.0f;
190 
191  for (int b = 0; b < blocks_per_row; b++) {
192  const block_q4_0 *block = &blocks[n * blocks_per_row + b];
193  const float d = CK_FP16_TO_FP32(block->d);
194  const float *ap = &a_row[b * QK4_0];
195 
196  for (int i = 0; i < QK4_0 / 2; i++) {
197  const uint8_t packed = block->qs[i];
198  const int q0 = (packed & 0x0F) - 8;
199  const int q1 = (packed >> 4) - 8;
200 
201  sum += d * (float)q0 * ap[2 * i + 0];
202  sum += d * (float)q1 * ap[2 * i + 1];
203  }
204  }
205 
206  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
207  }
208  }
209 }
ck_half d
Definition: ckernel_quant.h:38
uint8_t qs[32/2]
Definition: ckernel_quant.h:39

References C, CK_FP16_TO_FP32, block_q4_0::d, QK4_0, and block_q4_0::qs.

Referenced by ck_gemm_nt_quant().

◆ gemm_nt_q4_1()

void gemm_nt_q4_1 ( const float *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

GEMM with transposed Q4_1 weights: C = A @ B^T.

Parameters
AInput activations [M x K], row-major FP32
BWeight matrix in Q4_1 format [N x K], row-major quantized
biasOptional bias [N], NULL if not used
COutput [M x N], row-major FP32
MBatch size (number of tokens)
NOutput dimension
KInput dimension

Definition at line 256 of file gemm_kernels_q4_1.c.

261 {
262  const block_q4_1 *blocks = (const block_q4_1 *)B;
263  const int blocks_per_row = K / QK4_1;
264 
265  for (int m = 0; m < M; m++) {
266  const float *a_row = &A[m * K];
267 
268  for (int n = 0; n < N; n++) {
269  float sum = 0.0f;
270 
271  for (int b = 0; b < blocks_per_row; b++) {
272  const block_q4_1 *block = &blocks[n * blocks_per_row + b];
273  const float d = CK_FP16_TO_FP32(block->d);
274  const float min = CK_FP16_TO_FP32(block->m);
275  const float *ap = &a_row[b * QK4_1];
276 
277  for (int i = 0; i < QK4_1 / 2; i++) {
278  const uint8_t packed = block->qs[i];
279  const int q0 = (packed & 0x0F);
280  const int q1 = (packed >> 4);
281 
282  const float w0 = d * (float)q0 + min;
283  const float w1 = d * (float)q1 + min;
284 
285  sum += w0 * ap[2 * i + 0];
286  sum += w1 * ap[2 * i + 1];
287  }
288  }
289 
290  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
291  }
292  }
293 }
ck_half m
Definition: ckernel_quant.h:54
ck_half d
Definition: ckernel_quant.h:53
uint8_t qs[32/2]
Definition: ckernel_quant.h:55

References C, CK_FP16_TO_FP32, block_q4_1::d, block_q4_1::m, QK4_1, and block_q4_1::qs.

Referenced by ck_gemm_nt_quant().

◆ gemm_nt_q4_k()

/**
 * NT GEMM with Q4_K weights: C[M,N] = A[M,K] @ B[N,K]^T + bias.
 *
 * @param A    Input activations [M x K], row-major FP32
 * @param B    Weights in Q4_K format [N x K]
 * @param bias Optional bias [N], NULL if not used
 * @param C    Output [M x N], row-major FP32
 */
void gemm_nt_q4_k(const float *A, const void *B, const float *bias,
                  float *C, int M, int N, int K)
{
    if (A == NULL || B == NULL || C == NULL) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    /* gemm_q4_k writes Y as [batch x M_out]; here batch = M tokens and
     * M_out = N output channels, so C already has the layout we need. */
    gemm_q4_k(C, B, A, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias != NULL) {
        for (int t = 0; t < M; ++t) {
            float *out = C + (size_t)t * (size_t)N;
            for (int c = 0; c < N; ++c) {
                out[c] += bias[c];
            }
        }
    }
}
void gemm_q4_k(float *Y, const void *W, const float *X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.

References C, and gemm_q4_k().

Referenced by ck_attention_project_head_major_q4_k(), ck_gemm_nt_quant(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k(), ck_qkv_project_head_major_q4_k(), ck_qkv_project_head_major_token_q4_k(), model_decode_token(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), 
qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().

◆ gemm_nt_q4_k_q8_k()

/**
 * NT GEMM: C = A @ B^T where A is pre-quantized Q8_K and B is Q4_K.
 *
 * @param A_q8 Activations in Q8_K format [M x K]
 * @param B    Weights in Q4_K format [N x K]
 * @param bias Optional bias [N], NULL if not used
 * @param C    Output [M x N], row-major FP32
 */
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias,
                       float *C, int M, int N, int K)
{
    if (A_q8 == NULL || B == NULL || C == NULL) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    /* gemm_q4_k_q8_k writes [batch x M_out]; batch = M tokens, M_out = N. */
    gemm_q4_k_q8_k(C, B, A_q8, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias != NULL) {
        for (int t = 0; t < M; ++t) {
            float *out = C + (size_t)t * (size_t)N;
            for (int c = 0; c < N; ++c) {
                out[c] += bias[c];
            }
        }
    }
}
void gemm_q4_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)

References C, and gemm_q4_k_q8_k().

Referenced by ck_attention_project_head_major_q4_k_q8_k(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_qkv_project_head_major_token_q4_k_q8_k(), gemm_nt_q8_k_mlp_dispatch(), gemm_nt_q8_k_qkv_dispatch(), model_forward_prefill_impl(), and qwen2_0_5b_decode_forward_prefill_impl().

◆ gemm_nt_q5_0()

/**
 * NT GEMM with Q5_0 weights: C[M,N] = A[M,K] @ B[N,K]^T + bias.
 * Single-token calls (M == 1) take a direct GEMV fast path.
 *
 * @param A    Input activations [M x K], row-major FP32
 * @param B    Weights in Q5_0 format [N x K]
 * @param bias Optional bias [N], NULL if not used
 * @param C    Output [M x N], row-major FP32
 */
void gemm_nt_q5_0(const float *A, const void *B, const float *bias,
                  float *C, int M, int N, int K)
{
    if (M == 1) {
        /* Decode fast path: the AVX-optimized GEMV handles one token. */
        gemv_q5_0(C, B, A, N, K);
        if (bias != NULL) {
            for (int c = 0; c < N; c++) {
                C[c] += bias[c];
            }
        }
        return;
    }

    /* Prefill path (M > 1): gemm_q5_0 writes [batch x M_out] with
     * batch = M tokens and M_out = N output channels. */
    gemm_q5_0(C, B, A, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias != NULL) {
        for (int t = 0; t < M; t++) {
            float *out = C + (size_t)t * (size_t)N;
            for (int c = 0; c < N; c++) {
                out[c] += bias[c];
            }
        }
    }
}
void gemv_q5_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q5_0 weights based on CPU features.
void gemm_q5_0(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q5_0 weights.

References C, gemm_q5_0(), and gemv_q5_0().

Referenced by ck_gemm_nt_quant(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().

◆ gemm_nt_q5_1()

void gemm_nt_q5_1 ( const float *  A,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

GEMM with transposed Q5_1 weights: C = A @ B^T.

Parameters
AInput activations [M x K], row-major FP32
BWeight matrix in Q5_1 format [N x K], row-major quantized
biasOptional bias [N], NULL if not used
COutput [M x N], row-major FP32
MBatch size (number of tokens)
NOutput dimension
KInput dimension

Definition at line 309 of file gemm_kernels_q5_1.c.

314 {
315  const block_q5_1 *blocks = (const block_q5_1 *)B;
316  const int blocks_per_row = K / QK5_1;
317 
318  for (int m = 0; m < M; m++) {
319  const float *a_row = &A[m * K];
320 
321  for (int n = 0; n < N; n++) {
322  float sum = 0.0f;
323 
324  for (int b = 0; b < blocks_per_row; b++) {
325  const block_q5_1 *block = &blocks[n * blocks_per_row + b];
326  const float d = CK_FP16_TO_FP32(block->d);
327  const float min = CK_FP16_TO_FP32(block->m);
328  const float *ap = &a_row[b * QK5_1];
329 
330  uint32_t qh;
331  memcpy(&qh, block->qh, sizeof(qh));
332 
333  /* First 16 weights: low nibbles, high bits from qh[0:15] */
334  for (int j = 0; j < QK5_1 / 2; j++) {
335  const int lo = (block->qs[j] & 0x0F);
336  const int hi = ((qh >> j) & 1) << 4;
337  sum += (d * (float)(lo | hi) + min) * ap[j];
338  }
339 
340  /* Second 16 weights: high nibbles, high bits from qh[16:31] */
341  for (int j = 0; j < QK5_1 / 2; j++) {
342  const int lo = (block->qs[j] >> 4);
343  const int hi = ((qh >> (j + 16)) & 1) << 4;
344  sum += (d * (float)(lo | hi) + min) * ap[j + QK5_1 / 2];
345  }
346  }
347 
348  C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
349  }
350  }
351 }
uint8_t qs[32/2]
Definition: ckernel_quant.h:90
uint8_t qh[4]
Definition: ckernel_quant.h:89
ck_half m
Definition: ckernel_quant.h:88
ck_half d
Definition: ckernel_quant.h:87

References C, CK_FP16_TO_FP32, block_q5_1::d, block_q5_1::m, block_q5_1::qh, QK5_1, and block_q5_1::qs.

Referenced by ck_gemm_nt_quant().

◆ gemm_nt_q5_k()

/**
 * NT GEMM with Q5_K weights: C[M,N] = A[M,K] @ B[N,K]^T + bias.
 *
 * All SIMD paths are still TODO: the previous #if/#elif chain dispatched to
 * the same scalar reference kernel from every branch, so the chain was pure
 * noise — call the reference implementation directly instead.
 * TODO: add AVX-512 / AVX2 / AVX / SSE4.1 implementations and restore
 * compile-time dispatch here.
 */
void gemm_nt_q5_k(const float *A, const void *B, const float *bias,
                  float *C, int M, int N, int K)
{
    gemm_nt_q5_k_ref(A, B, bias, C, M, N, K);
}
void gemm_nt_q5_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)

References C, and gemm_nt_q5_k_ref().

◆ gemm_nt_q6_k()

/**
 * NT GEMM with Q6_K weights: C[M,N] = A[M,K] @ B[N,K]^T + bias.
 *
 * @param A    Input activations [M x K], row-major FP32
 * @param B    Weights in Q6_K format [N x K]
 * @param bias Optional bias [N], NULL if not used
 * @param C    Output [M x N], row-major FP32
 */
void gemm_nt_q6_k(const float *A, const void *B, const float *bias,
                  float *C, int M, int N, int K)
{
    if (A == NULL || B == NULL || C == NULL) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    /* gemm_q6_k writes Y as [batch x M_out]; batch = M tokens and
     * M_out = N output channels, which matches C's row-major layout. */
    gemm_q6_k(C, B, A, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias != NULL) {
        for (int t = 0; t < M; ++t) {
            float *out = C + (size_t)t * (size_t)N;
            for (int c = 0; c < N; ++c) {
                out[c] += bias[c];
            }
        }
    }
}
void gemm_q6_k(float *Y, const void *W, const float *X, int M, int N, int K)

References C, and gemm_q6_k().

Referenced by ck_gemm_nt_quant(), gemm_nt_q6_k_ref(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().

◆ gemm_nt_q6_k_q8_k()

/**
 * NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.
 *
 * Typical inference pattern:
 *   - A: activations in Q8_K format [M x K]
 *   - B: weights in Q6_K format [N x K]
 *   - C: output [M x N]
 *
 * @param A_q8 Input activations in Q8_K format
 * @param B    Weight matrix in Q6_K format
 * @param bias Optional bias vector [N]
 * @param C    Output matrix
 * @param M    Batch size (number of tokens)
 * @param N    Output dimension
 * @param K    Input dimension
 */
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias,
                       float *C, int M, int N, int K)
{
    if (A_q8 == NULL || B == NULL || C == NULL) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    /* gemm_q6_k_q8_k writes [batch x M_out]; batch = M tokens, M_out = N. */
    gemm_q6_k_q8_k(C, B, A_q8, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias != NULL) {
        for (int t = 0; t < M; ++t) {
            float *out = C + (size_t)t * (size_t)N;
            for (int c = 0; c < N; ++c) {
                out[c] += bias[c];
            }
        }
    }
}
void gemm_q6_k_q8_k(float *Y, const void *W, const void *X_q8, int M, int N, int K)
GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.

References C, and gemm_q6_k_q8_k().

Referenced by gemm_nt_q8_k_mlp_dispatch(), and gemm_nt_q8_k_qkv_dispatch().

◆ gemm_nt_q8_0()

/**
 * Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
 *
 * Delegates each output row to gemv_q8_0, which dispatches to the best
 * available AVX/SSE/scalar kernel at run time.
 *
 * Note: the previous version carried a full scalar reference loop after an
 * unconditional `return;` — that code was unreachable and has been removed.
 *
 * @param A    Input matrix [M x K], row-major FP32
 * @param B    Weight matrix in Q8_0 format, [N x K] stored row-major
 * @param bias Optional bias [N], NULL if not used
 * @param C    Output [M x N], row-major FP32
 * @param M    Batch size (number of tokens)
 * @param N    Output dimension (number of rows in B)
 * @param K    Input dimension
 */
void gemm_nt_q8_0(const float *A, const void *B, const float *bias,
                  float *C, int M, int N, int K)
{
    for (int m = 0; m < M; m++) {
        /* size_t indexing avoids int overflow on large token*dim products. */
        gemv_q8_0(&C[(size_t)m * N], B, &A[(size_t)m * K], N, K);
        if (bias) {
            for (int n = 0; n < N; n++) {
                C[(size_t)m * N + n] += bias[n];
            }
        }
    }
}
void gemv_q8_0(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV for Q8_0 weights based on CPU features.

References C, CK_FP16_TO_FP32, block_q8_0::d, gemv_q8_0(), QK8_0, and block_q8_0::qs.

Referenced by ck_gemm_nt_quant(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().

◆ gemm_nt_q8_0_q8_0()

/**
 * gemm_nt_q8_0_q8_0 with optional bias (matches header signature).
 *
 * C[m,n] = A[m,K] @ B[n,K]^T + bias[n]; A and B are both Q8_0-quantized.
 * Dispatches to the best ISA variant selected at compile time.
 */
void gemm_nt_q8_0_q8_0(const void *A, const void *B, const float *bias,
                       float *C, int M, int N, int K)
{
    /* Step 1: plain GEMM into C via the widest compiled kernel. */
#if defined(__AVX512VNNI__)
    gemm_nt_q8_0_q8_0_vnni(A, B, C, M, N, K);
#elif defined(__AVX512F__)
    gemm_nt_q8_0_q8_0_avx512(A, B, C, M, N, K);
#elif defined(__AVX2__)
    gemm_nt_q8_0_q8_0_avx2(A, B, C, M, N, K);
#elif defined(__AVX__)
    gemm_nt_q8_0_q8_0_avx(A, B, C, M, N, K);
#else
    gemm_nt_q8_0_q8_0_ref(A, B, C, M, N, K);
#endif

    /* Step 2: fold in the bias row-by-row, if one was supplied. */
    if (bias != NULL) {
        for (int m = 0; m < M; m++) {
            float *row = &C[(size_t)m * N];
            for (int n = 0; n < N; n++) {
                row[n] += bias[n];
            }
        }
    }
}
void gemm_nt_q8_0_q8_0_ref(const void *A, const void *B, float *C, int M, int N, int K)
Scalar reference: gemm_nt_q8_0_q8_0.

References C, and gemm_nt_q8_0_q8_0_ref().

Referenced by gemm_nt_q8_0_dispatch(), and gemm_nt_q8_0_mlp_dispatch().

◆ gemm_q4_k()

/**
 * Auto-dispatch GEMM for Q4_K weights.
 *
 * Currently always routes to the scalar reference kernel for correctness.
 * TODO: fix the AVX-512 version to match the llama.cpp block layout, then
 * restore SIMD dispatch here.
 */
void gemm_q4_k(float *Y, const void *W, const float *X, int M, int N, int K)
{
    gemm_q4_k_ref(Y, W, X, M, N, K);
}
void gemm_q4_k_ref(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q4_K weights (scalar reference)

References gemm_q4_k_ref().

Referenced by gemm_nt_q4_k().

◆ gemm_q4_k_q8_k()

void gemm_q4_k_q8_k ( float *  Y,
const void *  W,
const void *  X_q8,
int  M,
int  N,
int  K 
)

Definition at line 277 of file gemm_kernels_q4k_q8k.c.

281 {
282  if (!Y || !W || !X_q8 || M <= 0 || N <= 0 || K <= 0) {
283  return;
284  }
285 
286  const block_q8_K *X = (const block_q8_K *)X_q8;
287  const int blocks_per_vec = K / QK_K;
288 
289  for (int n = 0; n < N; ++n) {
290  const block_q8_K *x_row = X + (size_t)n * (size_t)blocks_per_vec;
291  gemv_q4_k_q8_k(&Y[n * M], W, x_row, M, K);
292  }
293 }
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)

References gemv_q4_k_q8_k(), and QK_K.

Referenced by gemm_nt_q4_k_q8_k().

◆ gemm_q6_k()

/**
 * GEMM with Q6_K weights: one gemv_q6_k call per input vector.
 *
 * @param Y Output [N x M], row-major
 * @param W Weight matrix in Q6_K format [M x K]
 * @param X Input vectors [N x K], row-major FP32
 * @param M Output dimension
 * @param N Batch size (number of input vectors)
 * @param K Input dimension
 */
void gemm_q6_k(float *Y, const void *W, const float *X, int M, int N, int K)
{
    if (!Y || !W || !X) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    for (int n = 0; n < N; ++n) {
        /* size_t indices: n*M and n*K can overflow int for large
         * batch x dimension products. */
        gemv_q6_k(&Y[(size_t)n * (size_t)M], W, &X[(size_t)n * (size_t)K], M, K);
    }
}
void gemv_q6_k(float *y, const void *W, const float *x, int M, int K)

References gemv_q6_k().

Referenced by gemm_nt_q6_k().

◆ gemm_q6_k_q8_k()

void gemm_q6_k_q8_k ( float *  Y,
const void *  W,
const void *  X_q8,
int  M,
int  N,
int  K 
)

GEMM: Y = W @ X^T where W is Q6_K and X is Q8_K.

Parameters
YOutput matrix [N x M] in row-major
WWeight matrix in Q6_K format [M x K]
X_q8Input matrix in Q8_K format [N x K]
MNumber of output rows (output dim)
NNumber of input vectors (batch size)
KInput dimension

Definition at line 1110 of file gemm_kernels_q6k_q8k.c.

1114 {
1115  if (!Y || !W || !X_q8 || M <= 0 || N <= 0 || K <= 0) {
1116  return;
1117  }
1118 
1119  const block_q8_K *X = (const block_q8_K *)X_q8;
1120  const int blocks_per_vec = K / QK_K;
1121 
1122  for (int n = 0; n < N; ++n) {
1123  const block_q8_K *x_row = X + (size_t)n * (size_t)blocks_per_vec;
1124  gemv_q6_k_q8_k(&Y[n * M], W, x_row, M, K);
1125  }
1126 }
void gemv_q6_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
GEMV: y = W @ x where W is Q6_K and x is Q8_K.

References gemv_q6_k_q8_k(), and QK_K.

Referenced by gemm_nt_q6_k_q8_k().

◆ gemm_swiglu_fused()

void gemm_swiglu_fused ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  b_gate,
const float *  b_up,
float *  output,
int  M,
int  N,
int  K 
)

Definition at line 241 of file gemm_fused_kernels.c.

248 {
249 #if defined(__AVX__)
250 #pragma omp parallel for
251  for (int i = 0; i < M; i++) {
252  const float *x_row = &x[i * K];
253  float *out_row = &output[i * N];
254 
255  for (int j = 0; j < N; j++) {
256  const float *w_gate_row = &W_gate[j * K];
257  const float *w_up_row = &W_up[j * K];
258 
259  // Compute both dot products in parallel using SIMD
260  __m256 gate_vec = _mm256_setzero_ps();
261  __m256 up_vec = _mm256_setzero_ps();
262 
263  int k;
264  for (k = 0; k <= K - 8; k += 8) {
265  __m256 x_vec = _mm256_loadu_ps(&x_row[k]);
266  __m256 wg_vec = _mm256_loadu_ps(&w_gate_row[k]);
267  __m256 wu_vec = _mm256_loadu_ps(&w_up_row[k]);
268 
269  // gate += x * W_gate
270  gate_vec = _mm256_add_ps(gate_vec, _mm256_mul_ps(x_vec, wg_vec));
271  // up += x * W_up
272  up_vec = _mm256_add_ps(up_vec, _mm256_mul_ps(x_vec, wu_vec));
273  }
274 
275  // Horizontal sum
276  float gate = hsum256_ps_fused(gate_vec);
277  float up = hsum256_ps_fused(up_vec);
278 
279  // Scalar remainder
280  for (; k < K; k++) {
281  gate += x_row[k] * w_gate_row[k];
282  up += x_row[k] * w_up_row[k];
283  }
284 
285  // Add biases
286  if (b_gate) gate += b_gate[j];
287  if (b_up) up += b_up[j];
288 
289  // SwiGLU: SiLU(gate) * up = gate * sigmoid(gate) * up
290  float sig = 1.0f / (1.0f + expf(-gate));
291  out_row[j] = gate * sig * up;
292  }
293  }
294 #else
295  // Scalar fallback
296 #pragma omp parallel for
297  for (int i = 0; i < M; i++) {
298  for (int j = 0; j < N; j++) {
299  float gate = 0.0f;
300  float up = 0.0f;
301 
302  for (int k = 0; k < K; k++) {
303  gate += x[i * K + k] * W_gate[j * K + k];
304  up += x[i * K + k] * W_up[j * K + k];
305  }
306 
307  if (b_gate) gate += b_gate[j];
308  if (b_up) up += b_up[j];
309 
310  // SwiGLU: SiLU(gate) * up
311  float sig = 1.0f / (1.0f + expf(-gate));
312  output[i * N + j] = gate * sig * up;
313  }
314  }
315 #endif
316 }

References hsum256_ps_fused().

Referenced by ck_mlp_swiglu_forward_fused_token().

◆ gemm_tn_avx512()

/**
 * TN GEMM: C[M,N] = A^T @ B + bias, where A is stored column-major as [K x M].
 * Vectorizes across output columns with AVX-512 or AVX when available;
 * otherwise defers to the generic parallel kernel.
 */
void gemm_tn_avx512(const float *A, const float *B, const float *bias,
                    float *C, int M, int N, int K)
{
    /* Strict-parity mode: defer to the double-precision serial kernel. */
    if (ck_strict_parity_enabled()) {
        gemm_tn_serial_double(A, B, bias, C, M, N, K);
        return;
    }
#if defined(__AVX512F__)
#pragma omp parallel for
    for (int row = 0; row < M; row++) {
        int col = 0;
        /* 16 output columns per iteration; seed accumulator with bias. */
        for (; col <= N - 16; col += 16) {
            __m512 acc = bias ? _mm512_loadu_ps(&bias[col]) : _mm512_setzero_ps();
            for (int kk = 0; kk < K; kk++) {
                __m512 a_bc = _mm512_set1_ps(A[kk * M + row]);
                acc = _mm512_fmadd_ps(a_bc, _mm512_loadu_ps(&B[kk * N + col]), acc);
            }
            _mm512_storeu_ps(&C[row * N + col], acc);
        }
        /* Scalar tail for the remaining columns. */
        for (; col < N; col++) {
            float acc = bias ? bias[col] : 0.0f;
            for (int kk = 0; kk < K; kk++) {
                acc += A[kk * M + row] * B[kk * N + col];
            }
            C[row * N + col] = acc;
        }
    }
#elif defined(__AVX__)
    /* AVX1: 8 output columns at a time. */
#pragma omp parallel for
    for (int row = 0; row < M; row++) {
        int col = 0;
        for (; col <= N - 8; col += 8) {
            __m256 acc = bias ? _mm256_loadu_ps(&bias[col]) : _mm256_setzero_ps();
            for (int kk = 0; kk < K; kk++) {
                __m256 a_bc = _mm256_set1_ps(A[kk * M + row]);
                acc = _mm256_add_ps(acc, _mm256_mul_ps(a_bc, _mm256_loadu_ps(&B[kk * N + col])));
            }
            _mm256_storeu_ps(&C[row * N + col], acc);
        }
        for (; col < N; col++) {
            float acc = bias ? bias[col] : 0.0f;
            for (int kk = 0; kk < K; kk++) {
                acc += A[kk * M + row] * B[kk * N + col];
            }
            C[row * N + col] = acc;
        }
    }
#else
    gemm_tn_parallel(A, B, bias, C, M, N, K);
#endif
}
void gemm_tn_parallel(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:499
static void gemm_tn_serial_double(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:481

References C, ck_strict_parity_enabled(), gemm_tn_parallel(), and gemm_tn_serial_double().

Referenced by fc1_backward_kernel(), and fc2_backward_kernel().

◆ gemm_tn_blocked()

/**
 * Cache-blocked TN GEMM: C[M,N] = A^T @ B + bias, with A stored [K x M].
 * C is first seeded with the bias, then accumulated block by block so each
 * tile of B and C stays resident in cache.
 */
void gemm_tn_blocked(const float *A, const float *B, const float *bias,
                     float *C, int M, int N, int K)
{
    /* Strict-parity mode: defer to the double-precision serial kernel. */
    if (ck_strict_parity_enabled()) {
        gemm_tn_serial_double(A, B, bias, C, M, N, K);
        return;
    }
#if defined(__AVX512F__)
    const int block_size = 64;
#elif defined(__AVX__)
    const int block_size = 32;
#else
    const int block_size = 32;
#endif
    /* Seed C with the bias (or zero) before accumulation. */
#pragma omp parallel for
    for (int row = 0; row < M; row++) {
        for (int col = 0; col < N; col++) {
            C[row * N + col] = bias ? bias[col] : 0.0f;
        }
    }
    /* Blocked multiply-accumulate, parallel over row blocks. */
#pragma omp parallel for
    for (int i0 = 0; i0 < M; i0 += block_size) {
        for (int k0 = 0; k0 < K; k0 += block_size) {
            for (int j0 = 0; j0 < N; j0 += block_size) {
                const int i_hi = ck_min(i0 + block_size, M);
                const int k_hi = ck_min(k0 + block_size, K);
                const int j_hi = ck_min(j0 + block_size, N);

                for (int k = k0; k < k_hi; k++) {
                    for (int i = i0; i < i_hi; i++) {
                        const float a_val = A[k * M + i]; /* A is [K x M] */
#if defined(__AVX512F__)
                        __m512 a_bc = _mm512_set1_ps(a_val);
                        int j = j0;
                        for (; j <= j_hi - 16; j += 16) {
                            __m512 c_vec = _mm512_loadu_ps(&C[i * N + j]);
                            c_vec = _mm512_fmadd_ps(a_bc, _mm512_loadu_ps(&B[k * N + j]), c_vec);
                            _mm512_storeu_ps(&C[i * N + j], c_vec);
                        }
                        for (; j < j_hi; j++) {
                            C[i * N + j] += a_val * B[k * N + j];
                        }
#elif defined(__AVX__)
                        __m256 a_bc = _mm256_set1_ps(a_val);
                        int j = j0;
                        for (; j <= j_hi - 8; j += 8) {
                            __m256 c_vec = _mm256_loadu_ps(&C[i * N + j]);
                            c_vec = _mm256_add_ps(c_vec, _mm256_mul_ps(a_bc, _mm256_loadu_ps(&B[k * N + j])));
                            _mm256_storeu_ps(&C[i * N + j], c_vec);
                        }
                        for (; j < j_hi; j++) {
                            C[i * N + j] += a_val * B[k * N + j];
                        }
#else
                        for (int j = j0; j < j_hi; j++) {
                            C[i * N + j] += a_val * B[k * N + j];
                        }
#endif
                    }
                }
            }
        }
    }
}

References C, ck_min(), ck_strict_parity_enabled(), and gemm_tn_serial_double().

◆ gemm_tn_parallel()

void gemm_tn_parallel ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 499 of file gemm_kernels.c.

504 {
505  if (ck_strict_parity_enabled()) {
506  gemm_tn_serial_double(A, B, bias, C, M, N, K);
507  return;
508  }
509 #pragma omp parallel for
510  for (int i = 0; i < M; i++) {
511  for (int j = 0; j < N; j++) {
512  float sum = bias ? bias[j] : 0.0f;
513  for (int k = 0; k < K; k++) {
514  sum += A[k * M + i] * B[k * N + j];
515  }
516  C[i * N + j] = sum;
517  }
518  }
519 }

References C, ck_strict_parity_enabled(), and gemm_tn_serial_double().

Referenced by gemm_tn_avx512().

◆ gemv_fused_q5_0_bias_dispatch()

/**
 * Compile-time dispatch for the fused Q5_0 GEMV-with-bias kernel:
 * AVX variant when compiled with AVX support, scalar otherwise.
 */
void gemv_fused_q5_0_bias_dispatch(float *y, const void *W, const float *x,
                                   const float *bias, int M, int K)
{
#if defined(__AVX__)
    gemv_fused_q5_0_bias_avx(y, W, x, bias, M, K);
#else
    gemv_fused_q5_0_bias(y, W, x, bias, M, K);
#endif
}
void gemv_fused_q5_0_bias(float *y, const void *W, const float *x, const float *bias, int M, int K)

References gemv_fused_q5_0_bias().

◆ gemv_fused_q8_0_bias_dispatch()

/**
 * Compile-time dispatch for the fused Q8_0 GEMV-with-bias kernel:
 * AVX variant when compiled with AVX support, scalar otherwise.
 */
void gemv_fused_q8_0_bias_dispatch(float *y, const void *W, const float *x,
                                   const float *bias, int M, int K)
{
#if defined(__AVX__)
    gemv_fused_q8_0_bias_avx(y, W, x, bias, M, K);
#else
    gemv_fused_q8_0_bias(y, W, x, bias, M, K);
#endif
}
void gemv_fused_q8_0_bias(float *y, const void *W, const float *x, const float *bias, int M, int K)

References gemv_fused_q8_0_bias().

◆ gemv_q4_0()

/**
 * Auto-dispatch GEMV for Q4_0 weights: AVX-512 kernel when compiled in,
 * scalar reference otherwise.
 */
void gemv_q4_0(float *y, const void *W, const float *x, int M, int K)
{
#if defined(__AVX512F__)
    gemv_q4_0_avx512(y, W, x, M, K);
#else
    gemv_q4_0_ref(y, W, x, M, K);
#endif
}
void gemv_q4_0_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_0 weights (scalar reference)

References gemv_q4_0_ref().

Referenced by dot_q4_0(), and gemm_q4_0().

◆ gemv_q4_k()

/**
 * Auto-dispatch GEMV for Q4_K weights, preferring the widest compiled ISA:
 * AVX-512 > AVX > scalar reference.
 */
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
{
#if defined(__AVX512F__)
    gemv_q4_k_avx512(y, W, x, M, K);
#elif defined(__AVX__)
    gemv_q4_k_avx(y, W, x, M, K);
#else
    gemv_q4_k_ref(y, W, x, M, K);
#endif
}
void gemv_q4_k_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_K weights (scalar reference)

References gemv_q4_k_ref().

Referenced by attention_mlp_fused_q4k(), dot_q4_k(), gemm_q4_k_ref(), layer_fused_attn_mlp_qkv_q4k(), and rmsnorm_qkv_q4k_fused().

◆ gemv_q4_k_q8_k()

/**
 * Auto-dispatch GEMV for Q4_K weights against Q8_K-quantized activations.
 * Preference order (widest compiled ISA first):
 *   VNNI  - INT8 dot-product acceleration, best for single-token decode
 *   AVX2
 *   AVX   - uses maddubs_epi16 (more efficient than the SSE path)
 *   SSE4.1
 *   scalar reference
 */
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K)
{
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
    gemv_q4_k_q8_k_vnni(y, W, x_q8, M, K);
#elif defined(__AVX2__)
    gemv_q4_k_q8_k_avx2(y, W, x_q8, M, K);
#elif defined(__AVX__)
    gemv_q4_k_q8_k_avx(y, W, x_q8, M, K);
#elif defined(__SSE4_1__)
    gemv_q4_k_q8_k_sse(y, W, x_q8, M, K);
#else
    gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
#endif
}
void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)

References gemv_q4_k_q8_k_avx(), gemv_q4_k_q8_k_avx2(), gemv_q4_k_q8_k_ref(), gemv_q4_k_q8_k_sse(), and gemv_q4_k_q8_k_vnni().

Referenced by model_decode_token(), model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().

◆ gemv_q4_k_q8_k_parallel()

void gemv_q4_k_q8_k_parallel ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K,
int  ith,
int  nth 
)

Definition at line 206 of file gemm_kernels_q4k_q8k.c.

211 {
212  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
213  return;
214  }
215  if (ith < 0 || nth <= 0 || ith >= nth) {
216  return;
217  }
218 
219  /* Compute row range for this thread */
220  const int dr = (M + nth - 1) / nth;
221  const int r0 = dr * ith;
222  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
223 
224  if (r0 >= M) {
225  return; /* This thread has no work */
226  }
227 
228  const block_q4_K *blocks = (const block_q4_K *)W;
229  const block_q8_K *x = (const block_q8_K *)x_q8;
230  const int blocks_per_row = K / QK_K;
231 
232  /* Only process rows [r0, r1) */
233  for (int row = r0; row < r1; ++row) {
234  const block_q4_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
235  y[row] = dot_q4_k_q8_k_ref(w_row, x, K);
236  }
237 }
static float dot_q4_k_q8_k_ref(const block_q4_K *w, const block_q8_K *x, int k)

References dot_q4_k_q8_k_ref(), and QK_K.

◆ gemv_q4_k_q8_k_parallel_simd()

void gemv_q4_k_q8_k_parallel_simd ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K,
int  ith,
int  nth 
)

Definition at line 263 of file gemm_kernels_q4k_avx.c.

268 {
269  /* Fall back to reference parallel version */
270  gemv_q4_k_q8_k_parallel(y, W, x_q8, M, K, ith, nth);
271 }
void gemv_q4_k_q8_k_parallel(float *y, const void *W, const void *x_q8, int M, int K, int ith, int nth)

References gemv_q4_k_q8_k_parallel().

Referenced by decode_layer_parallel(), mlp_parallel(), and qkv_projection_parallel().

◆ gemv_q4_k_q8_k_ref()

void gemv_q4_k_q8_k_ref ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Definition at line 177 of file gemm_kernels_q4k_q8k.c.

181 {
182  if (!y || !W || !x_q8 || M <= 0 || K <= 0) {
183  return;
184  }
185 
186  const block_q4_K *blocks = (const block_q4_K *)W;
187  const block_q8_K *x = (const block_q8_K *)x_q8;
188  const int blocks_per_row = K / QK_K;
189 
190  for (int row = 0; row < M; ++row) {
191  const block_q4_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
192  y[row] = dot_q4_k_q8_k_ref(w_row, x, K);
193  }
194 }

References dot_q4_k_q8_k_ref(), and QK_K.

◆ gemv_q5_0()

void gemv_q5_0 ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Auto-dispatch GEMV for Q5_0 weights based on CPU features.

Dispatch priority (best available):

  1. AVX-512 (512-bit vectors) - Intel Skylake-X+
  2. AVX2+FMA (256-bit vectors) - Intel Haswell+
  3. AVX (256-bit vectors) - Intel Sandy Bridge+
  4. SSE4.1 (128-bit vectors) - Intel Nehalem+
  5. Reference (scalar) - Fallback

Uses ck_features.h for standardized feature detection.

Parameters
yOutput vector [M]
WWeight matrix in Q5_0 format [M x K]
xInput vector [K]
MNumber of output rows
KNumber of input columns (hidden dimension)

Definition at line 547 of file gemm_kernels_q5_0.c.

551 {
552 // Dispatch order: AVX512 > AVX2 > AVX > SSE > ref
553 #if defined(__AVX512F__)
554  gemv_q5_0_avx512(y, W, x, M, K);
555 #elif defined(__AVX2__)
556  gemv_q5_0_avx2(y, W, x, M, K);
557 #elif defined(__AVX__)
558  gemv_q5_0_avx(y, W, x, M, K);
559 #elif defined(__SSE4_1__)
560  gemv_q5_0_sse_v2(y, W, x, M, K);
561 #else
562  gemv_q5_0_ref(y, W, x, M, K);
563 #endif
564 }
void gemv_q5_0_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q5_0 weights (scalar reference)

References gemv_q5_0_ref().

◆ gemv_q5_0_parallel()

void gemv_q5_0_parallel ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K,
int  ith,
int  nth 
)

Parallel reference GEMV for Q5_0 × FP32.

Definition at line 576 of file gemm_kernels_q5_0.c.

581 {
582  if (!y || !W || !x || M <= 0 || K <= 0) return;
583  if (ith < 0 || nth <= 0 || ith >= nth) return;
584 
585  const int dr = (M + nth - 1) / nth;
586  const int r0 = dr * ith;
587  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
588 
589  if (r0 >= M) return;
590 
591  const block_q5_0 *blocks = (const block_q5_0 *)W;
592  const int blocks_per_row = K / QK5_0;
593 
594  for (int row = r0; row < r1; row++) {
595  float sum = 0.0f;
596  for (int b = 0; b < blocks_per_row; b++) {
597  const block_q5_0 *block = &blocks[row * blocks_per_row + b];
598  const float d = CK_FP16_TO_FP32(block->d);
599  const float *xp = &x[b * QK5_0];
600 
601  uint32_t qh;
602  memcpy(&qh, block->qh, sizeof(qh));
603 
604  for (int j = 0; j < QK5_0 / 2; j++) {
605  const uint8_t packed = block->qs[j];
606  const int lo = (packed & 0x0F);
607  const int hi = (packed >> 4);
608  const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
609  const int xh_1 = ((qh >> (j + 12))) & 0x10;
610  const int w0 = (lo | xh_0) - 16;
611  const int w1 = (hi | xh_1) - 16;
612  sum += d * (w0 * xp[j] + w1 * xp[j + QK5_0/2]);
613  }
614  }
615  y[row] = sum;
616  }
617 }
ck_half d
Definition: ckernel_quant.h:70
uint8_t qh[4]
Definition: ckernel_quant.h:71
uint8_t qs[32/2]
Definition: ckernel_quant.h:72

References CK_FP16_TO_FP32, block_q5_0::d, block_q5_0::qh, QK5_0, and block_q5_0::qs.

Referenced by gemv_q5_0_parallel_simd().

◆ gemv_q5_0_parallel_simd()

void gemv_q5_0_parallel_simd ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K,
int  ith,
int  nth 
)

Parallel SIMD GEMV for Q5_0 × FP32 with prefetching.

Definition at line 622 of file gemm_kernels_q5_0.c.

627 {
628  if (!y || !W || !x || M <= 0 || K <= 0) return;
629  if (ith < 0 || nth <= 0 || ith >= nth) return;
630 
631  const int dr = (M + nth - 1) / nth;
632  const int r0 = dr * ith;
633  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
634 
635  if (r0 >= M) return;
636 
637  const block_q5_0 *blocks = (const block_q5_0 *)W;
638  const int blocks_per_row = K / QK5_0;
639 
640 #if defined(__AVX__) || defined(__SSE4_1__)
641  /* Prefetch first few rows */
642  const int PREFETCH_ROWS = 4;
643  for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
644  const char *row_ptr = (const char *)(blocks + (r0 + p) * blocks_per_row);
645  _mm_prefetch(row_ptr, _MM_HINT_T0);
646  _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
647  }
648 
649  for (int row = r0; row < r1; ++row) {
650  /* Prefetch rows ahead */
651  if (row + PREFETCH_ROWS < r1) {
652  const char *prefetch_ptr = (const char *)(blocks + (row + PREFETCH_ROWS) * blocks_per_row);
653  _mm_prefetch(prefetch_ptr, _MM_HINT_T0);
654  _mm_prefetch(prefetch_ptr + 64, _MM_HINT_T0);
655  }
656 
657  /* Use SIMD dot product for this row */
658 #if defined(__AVX512F__)
659  /* Call single-row AVX512 implementation */
660  gemv_q5_0_avx512(&y[row], (const char *)blocks + row * blocks_per_row * sizeof(block_q5_0), x, 1, K);
661 #elif defined(__AVX2__)
662  gemv_q5_0_avx2(&y[row], (const char *)blocks + row * blocks_per_row * sizeof(block_q5_0), x, 1, K);
663 #elif defined(__AVX__)
664  gemv_q5_0_avx(&y[row], (const char *)blocks + row * blocks_per_row * sizeof(block_q5_0), x, 1, K);
665 #else
666  gemv_q5_0_sse_v2(&y[row], (const char *)blocks + row * blocks_per_row * sizeof(block_q5_0), x, 1, K);
667 #endif
668  }
669 #else
670  /* Fallback to reference parallel */
671  gemv_q5_0_parallel(y, W, x, M, K, ith, nth);
672 #endif
673 }
void gemv_q5_0_parallel(float *y, const void *W, const float *x, int M, int K, int ith, int nth)
Parallel reference GEMV for Q5_0 × FP32.

References gemv_q5_0_parallel(), and QK5_0.

◆ gemv_q5_0_q8_0()

void gemv_q5_0_q8_0 ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Matrix-vector multiply with Q5_0 weights and Q8_0 input.

Parameters
yOutput vector [M]
WWeight matrix in Q5_0 format [M x K]
x_q8Input vector in Q8_0 format [K]
MNumber of output rows
KNumber of columns (must be multiple of 32)

Definition at line 1529 of file gemm_kernels_q5_0.c.

1533 {
1534  const block_q5_0 *w_blocks = (const block_q5_0 *)W;
1535  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1536  const int blocks_per_row = K / QK5_0;
1537 
1538  for (int row = 0; row < M; row++) {
1539  vec_dot_q5_0_q8_0(K, &y[row],
1540  &w_blocks[row * blocks_per_row],
1541  x_blocks);
1542  }
1543 }
void vec_dot_q5_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q5_0 x Q8_0.

References QK5_0, and vec_dot_q5_0_q8_0().

◆ gemv_q5_1()

void gemv_q5_1 ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Auto-dispatch GEMV.

Definition at line 184 of file gemm_kernels_q5_1.c.

188 {
189 #ifdef __AVX512F__
190  gemv_q5_1_avx512(y, W, x, M, K);
191 #else
192  gemv_q5_1_ref(y, W, x, M, K);
193 #endif
194 }
void gemv_q5_1_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q5_1 weights (scalar reference)

References gemv_q5_1_ref().

◆ gemv_q5_k()

void gemv_q5_k ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Definition at line 199 of file gemm_kernels_q5_k.c.

200 {
201 #if defined(__AVX512F__)
202  /* TODO: AVX-512 implementation */
203  gemv_q5_k_ref(y, W, x, M, K);
204 #elif defined(__AVX2__)
205  /* TODO: AVX-2 implementation */
206  gemv_q5_k_ref(y, W, x, M, K);
207 #elif defined(__AVX__)
208  /* TODO: AVX implementation */
209  gemv_q5_k_ref(y, W, x, M, K);
210 #elif defined(__SSE4_1__)
211  /* TODO: SSE4.1 implementation */
212  gemv_q5_k_ref(y, W, x, M, K);
213 #else
214  gemv_q5_k_ref(y, W, x, M, K);
215 #endif
216 }
void gemv_q5_k_ref(float *y, const void *W, const float *x, int M, int K)

References gemv_q5_k_ref().

◆ gemv_q6_k()

void gemv_q6_k ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Definition at line 169 of file gemm_kernels_q6k.c.

173 {
174  if (!y || !W || !x) {
175  return;
176  }
177  if (M <= 0 || K <= 0) {
178  return;
179  }
180  // TEMPORARILY DISABLE NEW AVX KERNELS - USE REFERENCE ONLY
181 
182  const block_q6_K *blocks = (const block_q6_K *)W;
183  const int blocks_per_row = K / QK_K;
184 
185  for (int row = 0; row < M; ++row) {
186  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
187 #if defined(__AVX__) && !defined(__AVX512F__)
188  y[row] = dot_q6_k_avx(w_row, x, K);
189 #else
190  y[row] = dot_q6_k_ref(w_row, x, K);
191 #endif
192  }
193 }
static float dot_q6_k_ref(const block_q6_K *w, const float *x, int K)

References dot_q6_k_ref(), and QK_K.

Referenced by gemm_q6_k().

◆ gemv_q6_k_q8_k()

void gemv_q6_k_q8_k ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

GEMV: y = W @ x where W is Q6_K and x is Q8_K.

Definition at line 980 of file gemm_kernels_q6k_q8k.c.

984 {
985  /* AVX-512 uses same algorithm as AVX2 (matches llama.cpp) */
986 #if defined(__AVX512F__) && defined(__AVX512BW__)
987  gemv_q6_k_q8_k_avx512(y, W, x_q8, M, K);
988 #elif defined(__AVX2__)
989  gemv_q6_k_q8_k_avx2(y, W, x_q8, M, K);
990 #elif defined(__AVX__)
991  gemv_q6_k_q8_k_avx(y, W, x_q8, M, K);
992 #elif defined(__SSSE3__)
993  gemv_q6_k_q8_k_sse(y, W, x_q8, M, K);
994 #else
995  gemv_q6_k_q8_k_ref(y, W, x_q8, M, K);
996 #endif
997 }
void gemv_q6_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_sse(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q6_k_q8_k_avx512(float *y, const void *W, const void *x_q8, int M, int K)

References gemv_q6_k_q8_k_avx(), gemv_q6_k_q8_k_avx2(), gemv_q6_k_q8_k_avx512(), gemv_q6_k_q8_k_ref(), and gemv_q6_k_q8_k_sse().

◆ gemv_q6_k_q8_k_parallel()

void gemv_q6_k_q8_k_parallel ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K,
int  ith,
int  nth 
)

Parallel reference GEMV for Q6_K × Q8_K.

Caller provides ith (thread index) and nth (total threads). Each thread processes rows [r0, r1).

Definition at line 1014 of file gemm_kernels_q6k_q8k.c.

1019 {
1020  if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
1021  if (ith < 0 || nth <= 0 || ith >= nth) return;
1022 
1023  /* Compute row range for this thread */
1024  const int dr = (M + nth - 1) / nth;
1025  const int r0 = dr * ith;
1026  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1027 
1028  if (r0 >= M) return;
1029 
1030  const block_q6_K *blocks = (const block_q6_K *)W;
1031  const block_q8_K *x = (const block_q8_K *)x_q8;
1032  const int blocks_per_row = K / QK_K;
1033 
1034  for (int row = r0; row < r1; ++row) {
1035  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
1036  y[row] = dot_q6_k_q8_k_ref(w_row, x, K);
1037  }
1038 }
static float dot_q6_k_q8_k_ref(const block_q6_K *w, const block_q8_K *x, int K)
Scalar dot product for Q6_K x Q8_K.

References dot_q6_k_q8_k_ref(), and QK_K.

◆ gemv_q6_k_q8_k_parallel_simd()

void gemv_q6_k_q8_k_parallel_simd ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K,
int  ith,
int  nth 
)

Parallel SIMD GEMV for Q6_K × Q8_K.

Uses best available SIMD (AVX/SSE) with row prefetching. Caller provides ith/nth from OpenMP region.

Definition at line 1046 of file gemm_kernels_q6k_q8k.c.

1051 {
1052  if (!y || !W || !x_q8 || M <= 0 || K <= 0) return;
1053  if (ith < 0 || nth <= 0 || ith >= nth) return;
1054 
1055  const int dr = (M + nth - 1) / nth;
1056  const int r0 = dr * ith;
1057  const int r1 = (r0 + dr < M) ? (r0 + dr) : M;
1058 
1059  if (r0 >= M) return;
1060 
1061  const block_q6_K *blocks = (const block_q6_K *)W;
1062  const block_q8_K *x = (const block_q8_K *)x_q8;
1063  const int blocks_per_row = K / QK_K;
1064 
1065 #if defined(__AVX__) || defined(__SSSE3__)
1066  /* Prefetch first few rows */
1067  const int PREFETCH_ROWS = 4;
1068  for (int p = 0; p < PREFETCH_ROWS && r0 + p < r1; ++p) {
1069  const char *row_ptr = (const char *)(blocks + (r0 + p) * blocks_per_row);
1070  _mm_prefetch(row_ptr, _MM_HINT_T0);
1071  _mm_prefetch(row_ptr + 64, _MM_HINT_T0);
1072  }
1073 
1074  for (int row = r0; row < r1; ++row) {
1075  /* Prefetch rows ahead */
1076  if (row + PREFETCH_ROWS < r1) {
1077  const char *prefetch_ptr = (const char *)(blocks + (row + PREFETCH_ROWS) * blocks_per_row);
1078  _mm_prefetch(prefetch_ptr, _MM_HINT_T0);
1079  _mm_prefetch(prefetch_ptr + 64, _MM_HINT_T0);
1080  }
1081 
1082  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
1083 #if defined(__AVX2__)
1084  y[row] = dot_q6_k_q8_k_avx2(w_row, x, K);
1085 #elif defined(__AVX__)
1086  y[row] = dot_q6_k_q8_k_avx(w_row, x, K);
1087 #else
1088  y[row] = dot_q6_k_q8_k_sse(w_row, x, K);
1089 #endif
1090  }
1091 #else
1092  /* Fallback to reference */
1093  for (int row = r0; row < r1; ++row) {
1094  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
1095  y[row] = dot_q6_k_q8_k_ref(w_row, x, K);
1096  }
1097 #endif
1098 }

References dot_q6_k_q8_k_ref(), and QK_K.

◆ gemv_q8_0()

void gemv_q8_0 ( float *  y,
const void *  W,
const float *  x,
int  M,
int  K 
)

Auto-dispatch GEMV for Q8_0 weights based on CPU features.

Dispatch priority (best available):

  1. AVX-512 (512-bit vectors) - Intel Skylake-X+
  2. AVX2+FMA (256-bit vectors) - Intel Haswell+
  3. AVX (256-bit vectors) - Intel Sandy Bridge+
  4. SSE4.1 (128-bit vectors) - Intel Nehalem+
  5. Reference (scalar) - Fallback

Uses ck_features.h for standardized feature detection.

Parameters
yOutput vector [M]
WWeight matrix in Q8_0 format [M x K]
xInput vector [K]
MNumber of output rows
KNumber of input columns (hidden dimension)

Definition at line 630 of file gemm_kernels_q8_0.c.

634 {
635 // Dispatch order: AVX512 > AVX2 > AVX > SSE > ref
636 #if defined(__AVX512F__)
637  gemv_q8_0_avx512(y, W, x, M, K);
638 #elif defined(__AVX2__)
639  gemv_q8_0_avx2(y, W, x, M, K);
640 #elif defined(__AVX__)
641  gemv_q8_0_avx(y, W, x, M, K);
642 #elif defined(__SSE4_1__)
643  gemv_q8_0_sse(y, W, x, M, K);
644 #else
645  gemv_q8_0_ref(y, W, x, M, K);
646 #endif
647 }
void gemv_q8_0_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q8_0 weights (scalar reference)

References gemv_q8_0_ref().

◆ gemv_q8_0_q8_0()

void gemv_q8_0_q8_0 ( float *  y,
const void *  W,
const void *  x_q8,
int  M,
int  K 
)

Matrix-vector multiply with Q8_0 weights and Q8_0 input.

Parameters
yOutput vector [M]
WWeight matrix in Q8_0 format [M x K]
x_q8Input vector in Q8_0 format [K]
MNumber of output rows
KNumber of columns (must be multiple of 32)

Definition at line 1042 of file gemm_kernels_q8_0.c.

1046 {
1047  const block_q8_0 *w_blocks = (const block_q8_0 *)W;
1048  const block_q8_0 *x_blocks = (const block_q8_0 *)x_q8;
1049  const int blocks_per_row = K / QK8_0;
1050 
1051  for (int row = 0; row < M; row++) {
1052  vec_dot_q8_0_q8_0(K, &y[row],
1053  &w_blocks[row * blocks_per_row],
1054  x_blocks);
1055  }
1056 }
void vec_dot_q8_0_q8_0(int n, float *s, const void *vx, const void *vy)
Auto-dispatch quantized dot product Q8_0 x Q8_0.

References QK8_0, and vec_dot_q8_0_q8_0().

◆ im2patch()

void im2patch ( const float *  image,
float *  patches,
int  C,
int  H,
int  W,
int  P 
)

im2patch: Transforms an image into a sequence of flattened patches.

Image Layout: [C, H, W] (Row-major: W is fastest moving) Output Layout: [num_patches, C * P * P]

num_patches = (H/P) * (W/P) P = patch_size

Definition at line 28 of file vision_kernels.c.

31 {
32  int num_patches_h = H / P;
33  int num_patches_w = W / P;
34  int patch_dim = C * P * P;
35 
36  // ph, pw: patch grid coordinates
37  for (int ph = 0; ph < num_patches_h; ++ph) {
38  for (int pw = 0; pw < num_patches_w; ++pw) {
39 
40  int patch_idx = ph * num_patches_w + pw;
41  float *dst_patch = patches + (size_t)patch_idx * patch_dim;
42 
43  // For each patch, grab pixels from all channels
44  for (int c = 0; c < C; ++c) {
45  for (int py = 0; py < P; ++py) {
46  int y = ph * P + py;
47  int x = pw * P;
48 
49  // Input row start in the image
50  const float *src_row = image + (size_t)c * H * W + (size_t)y * W + x;
51 
52  // Destination row in the flattened patch sequence
53  float *dst_row = dst_patch + (size_t)c * P * P + (size_t)py * P;
54 
55  // Copy P pixels (one row of the patch)
56  memcpy(dst_row, src_row, P * sizeof(float));
57  }
58  }
59  }
60  }
61 }

References C.

◆ im2patch_bf16()

void im2patch_bf16 ( const uint16_t *  image,
uint16_t *  patches,
int  C,
int  H,
int  W,
int  P 
)

Definition at line 22 of file vision_kernels_bf16.c.

28 {
29  if (!image || !patches || C <= 0 || H <= 0 || W <= 0 || P <= 0) {
30  return;
31  }
32 
33  int num_patches_h = H / P;
34  int num_patches_w = W / P;
35  int patch_dim = C * P * P;
36 
37  for (int ph = 0; ph < num_patches_h; ++ph) {
38  for (int pw = 0; pw < num_patches_w; ++pw) {
39  int patch_idx = ph * num_patches_w + pw;
40  uint16_t *dst_patch = patches + (size_t)patch_idx * (size_t)patch_dim;
41 
42  for (int c = 0; c < C; ++c) {
43  for (int py = 0; py < P; ++py) {
44  int y = ph * P + py;
45  int x = pw * P;
46 
47  const uint16_t *src_row = image + (size_t)c * (size_t)H * (size_t)W + (size_t)y * (size_t)W + (size_t)x;
48  uint16_t *dst_row = dst_patch + (size_t)c * (size_t)P * (size_t)P + (size_t)py * (size_t)P;
49 
50  memcpy(dst_row, src_row, (size_t)P * sizeof(uint16_t));
51  }
52  }
53  }
54  }
55 }

References C.

◆ kv_cache_repack_head_major_inplace()

void kv_cache_repack_head_major_inplace ( float *  buf,
int  num_heads,
int  tokens,
int  cache_capacity,
int  aligned_head_dim 
)

Definition at line 28 of file kv_cache_kernels.c.

33 {
34  if (!buf) {
35  return;
36  }
37  if (num_heads <= 0 || tokens <= 0 || cache_capacity <= 0 || aligned_head_dim <= 0) {
38  return;
39  }
40  if (tokens > cache_capacity) {
41  tokens = cache_capacity;
42  }
43  if (tokens == cache_capacity) {
44  return;
45  }
46 
47  const size_t old_head_stride = (size_t)tokens * (size_t)aligned_head_dim;
48  const size_t new_head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
49  const size_t bytes = (size_t)tokens * (size_t)aligned_head_dim * sizeof(float);
50 
51  // Move head blocks from high to low to avoid overwriting source data
52  // for heads that have not yet been moved.
53  for (int h = num_heads - 1; h >= 0; --h) {
54  float *src = buf + (size_t)h * old_head_stride;
55  float *dst = buf + (size_t)h * new_head_stride;
56  memmove(dst, src, bytes);
57  }
58 }

Referenced by qwen2_0_5b_decode_forward_prefill_impl().

◆ kv_cache_store()

void kv_cache_store ( float *__restrict  kv_cache_k,
float *__restrict  kv_cache_v,
const float *__restrict  k,
const float *__restrict  v,
int  layer,
int  pos,
int  num_kv_heads,
int  head_dim,
int  max_seq_len 
)

Definition at line 101 of file kv_cache_kernels.c.

110 {
111  (void)layer;
112  kv_cache_write_head_major(
113  kv_cache_k, kv_cache_v,
114  num_kv_heads,
115  pos,
116  max_seq_len,
117  head_dim,
118  head_dim);
119 }
void kv_cache_write_head_major(const float *__restrict k_token, const float *__restrict v_token, float *__restrict k_cache, float *__restrict v_cache, int num_kv_heads, int token_index, int cache_capacity, int head_dim, int aligned_head_dim)

References kv_cache_write_head_major().

◆ kv_cache_write_head_major()

void kv_cache_write_head_major ( const float *__restrict  k_token,
const float *__restrict  v_token,
float *__restrict  k_cache,
float *__restrict  v_cache,
int  num_kv_heads,
int  token_index,
int  cache_capacity,
int  head_dim,
int  aligned_head_dim 
)

Definition at line 60 of file kv_cache_kernels.c.

69 {
70  if (!k_token || !v_token || !k_cache || !v_cache) {
71  return;
72  }
73  if (num_kv_heads <= 0 || token_index < 0 || cache_capacity <= 0) {
74  return;
75  }
76  if (token_index >= cache_capacity || head_dim <= 0 || aligned_head_dim <= 0) {
77  return;
78  }
79 
80  const size_t head_stride = (size_t)cache_capacity * (size_t)aligned_head_dim;
81  const size_t token_stride = (size_t)aligned_head_dim;
82 
83  for (int h = 0; h < num_kv_heads; ++h) {
84  const float *k_src = k_token + (size_t)h * token_stride;
85  const float *v_src = v_token + (size_t)h * token_stride;
86 
87  float *k_dst = k_cache + (size_t)h * head_stride + (size_t)token_index * token_stride;
88  float *v_dst = v_cache + (size_t)h * head_stride + (size_t)token_index * token_stride;
89 
90  for (int d = 0; d < head_dim; ++d) {
91  k_dst[d] = k_src[d];
92  v_dst[d] = v_src[d];
93  }
94  for (int d = head_dim; d < aligned_head_dim; ++d) {
95  k_dst[d] = 0.0f;
96  v_dst[d] = 0.0f;
97  }
98  }
99 }

Referenced by ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), kv_cache_store(), mega_fused_attention_decode(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), and qwen2_0_5b_decode_layer_9_decode().

◆ layernorm_backward_kernel()

void layernorm_backward_kernel ( const float *  d_output,
const float *  input,
const float *  gamma,
const float *  mean,
const float *  rstd,
float *  d_input,
float *  d_gamma,
float *  d_beta,
int  tokens,
int  d_model,
int  aligned_embed_dim 
)

Definition at line 668 of file layernorm_kernels.c.

677 {
678  int T = tokens;
679  int D = d_model;
680  int aligned_D = aligned_embed_dim;
681 
682  // Per-token input gradients
683  for (int t = 0; t < T; ++t) {
684  float mean_t = mean[t];
685  float rstd_t = rstd[t];
686 
687  float d_y_gamma_sum = 0.0f;
688  float d_y_gamma_xhat_sum = 0.0f;
689 
690  // First pass: compute sums
691  for (int d = 0; d < D; ++d) {
692  float x = input[t * aligned_D + d];
693  float x_hat = (x - mean_t) * rstd_t;
694  float d_y = d_output[t * aligned_D + d];
695  float d_y_gamma = d_y * gamma[d];
696 
697  d_y_gamma_sum += d_y_gamma;
698  d_y_gamma_xhat_sum += d_y_gamma * x_hat;
699  }
700 
701  // Second pass: compute input gradients
702  float scale = rstd_t / (float)D;
703  for (int d = 0; d < D; ++d) {
704  float x = input[t * aligned_D + d];
705  float x_hat = (x - mean_t) * rstd_t;
706  float d_y = d_output[t * aligned_D + d];
707 
708  d_input[t * aligned_D + d] =
709  scale * ((float)D * d_y * gamma[d] - d_y_gamma_sum - x_hat * d_y_gamma_xhat_sum);
710  }
711 
712  // Zero padding for aligned dimension beyond D
713  for (int d = D; d < aligned_D; ++d) {
714  d_input[t * aligned_D + d] = 0.0f;
715  }
716  }
717 
718  // Parameter gradients (gamma, beta)
719  for (int d = 0; d < D; ++d) {
720  float gamma_grad = 0.0f;
721  float beta_grad = 0.0f;
722 
723  for (int t = 0; t < T; ++t) {
724  float x = input[t * aligned_D + d];
725  float x_hat = (x - mean[t]) * rstd[t];
726  float d_y = d_output[t * aligned_D + d];
727 
728  gamma_grad += d_y * x_hat;
729  beta_grad += d_y;
730  }
731 
732  d_gamma[d] += gamma_grad;
733  d_beta[d] += beta_grad;
734  }
735 }

Referenced by layernorm_backward_kernel_bf16().

◆ layernorm_backward_kernel_bf16()

void layernorm_backward_kernel_bf16 ( const uint16_t *  d_output,
const uint16_t *  input,
const float *  gamma,
const float *  mean,
const float *  rstd,
uint16_t *  d_input,
float *  d_gamma,
float *  d_beta,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float *  scratch_d_output,
float *  scratch_input,
float *  scratch_d_input 
)

Definition at line 84 of file layernorm_kernels_bf16.c.

96 {
97  if (!scratch_d_output || !scratch_input || !scratch_d_input) return;
98 
99  size_t total = (size_t)tokens * (size_t)aligned_embed_dim;
100 
101  bf16_tensor_to_float(d_output, scratch_d_output, total);
102  bf16_tensor_to_float(input, scratch_input, total);
103 
104  layernorm_backward_kernel(scratch_d_output, scratch_input, gamma, mean, rstd,
105  scratch_d_input, d_gamma, d_beta,
106  tokens, d_model, aligned_embed_dim);
107 
108  float_tensor_to_bf16(scratch_d_input, d_input, total);
109 }
void layernorm_backward_kernel(const float *d_output, const float *input, const float *gamma, const float *mean, const float *rstd, float *d_input, float *d_gamma, float *d_beta, int tokens, int d_model, int aligned_embed_dim)

References bf16_tensor_to_float(), float_tensor_to_bf16(), and layernorm_backward_kernel().

◆ layernorm_forward_rolled_slice()

void layernorm_forward_rolled_slice ( const float *__restrict  input_slice_base,
const float *__restrict  gamma,
const float *__restrict  beta,
float *__restrict  output_slice_base,
float *__restrict  mean_cache_slice,
float *__restrict  rstd_cache_slice,
int  num_tokens_in_slice,
int  d_model,
int  aligned_embed_dim,
float  eps 
)

Definition at line 274 of file layernorm_kernels.c.

284 {
285 #if defined(__AVX512F__)
286  layernorm_forward_rolled_slice_avx512(input_slice_base, gamma, beta,
287  output_slice_base, mean_cache_slice, rstd_cache_slice,
288  num_tokens_in_slice, d_model, aligned_embed_dim, eps);
289 #elif defined(__AVX2__) || defined(__AVX__)
290  layernorm_forward_rolled_slice_avx256(input_slice_base, gamma, beta,
291  output_slice_base, mean_cache_slice, rstd_cache_slice,
292  num_tokens_in_slice, d_model, aligned_embed_dim, eps);
293 #else
294  layernorm_naive_serial(input_slice_base, gamma, beta,
295  output_slice_base, mean_cache_slice, rstd_cache_slice,
296  num_tokens_in_slice, d_model, aligned_embed_dim, eps);
297 #endif
298 }
void layernorm_naive_serial(const float *input, const float *gamma, const float *beta, float *output, float *mean_cache, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)

References layernorm_naive_serial().

Referenced by layernorm_forward_rolled_slice_bf16().

◆ layernorm_forward_rolled_slice_bf16()

void layernorm_forward_rolled_slice_bf16 ( const uint16_t *__restrict  input_slice_base,
const float *__restrict  gamma,
const float *__restrict  beta,
uint16_t *__restrict  output_slice_base,
float *__restrict  mean_cache_slice,
float *__restrict  rstd_cache_slice,
int  num_tokens_in_slice,
int  d_model,
int  aligned_embed_dim,
float  eps,
float *  scratch_input,
float *  scratch_output 
)

Definition at line 30 of file layernorm_kernels_bf16.c.

42 {
43  if (!scratch_input || !scratch_output) return;
44 
45  size_t total = (size_t)num_tokens_in_slice * (size_t)aligned_embed_dim;
46 
47  bf16_tensor_to_float(input_slice_base, scratch_input, total);
48  layernorm_forward_rolled_slice(scratch_input, gamma, beta,
49  scratch_output, mean_cache_slice, rstd_cache_slice,
50  num_tokens_in_slice, d_model, aligned_embed_dim, eps);
51  float_tensor_to_bf16(scratch_output, output_slice_base, total);
52 }
void layernorm_forward_rolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, int aligned_embed_dim, float eps)

References bf16_tensor_to_float(), float_tensor_to_bf16(), and layernorm_forward_rolled_slice().

◆ layernorm_forward_unrolled_slice()

void layernorm_forward_unrolled_slice ( const float *__restrict  input_slice_base,
const float *__restrict  gamma,
const float *__restrict  beta,
float *__restrict  output_slice_base,
float *__restrict  mean_cache_slice,
float *__restrict  rstd_cache_slice,
int  num_tokens_in_slice,
int  d_model,
float  eps 
)

Definition at line 598 of file layernorm_kernels.c.

607 {
608 #if defined(__AVX512F__)
609  layernorm_forward_unrolled_slice_avx512(input_slice_base, gamma, beta,
610  output_slice_base, mean_cache_slice, rstd_cache_slice,
611  num_tokens_in_slice, d_model, eps);
612 #elif defined(__AVX2__) || defined(__AVX__)
613  layernorm_forward_unrolled_slice_avx256(input_slice_base, gamma, beta,
614  output_slice_base, mean_cache_slice, rstd_cache_slice,
615  num_tokens_in_slice, d_model, eps);
616 #else
617  layernorm_forward_unrolled_slice_scalar(input_slice_base, gamma, beta,
618  output_slice_base, mean_cache_slice, rstd_cache_slice,
619  num_tokens_in_slice, d_model, eps);
620 #endif
621 }
static void layernorm_forward_unrolled_slice_scalar(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)

References layernorm_forward_unrolled_slice_scalar().

Referenced by layernorm_forward_unrolled_slice_bf16().

◆ layernorm_forward_unrolled_slice_bf16()

void layernorm_forward_unrolled_slice_bf16 ( const uint16_t *__restrict  input_slice_base,
const float *__restrict  gamma,
const float *__restrict  beta,
uint16_t *__restrict  output_slice_base,
float *__restrict  mean_cache_slice,
float *__restrict  rstd_cache_slice,
int  num_tokens_in_slice,
int  d_model,
float  eps,
float *  scratch_input,
float *  scratch_output 
)

Definition at line 57 of file layernorm_kernels_bf16.c.

68 {
69  if (!scratch_input || !scratch_output) return;
70 
71  size_t total = (size_t)num_tokens_in_slice * (size_t)d_model;
72 
73  bf16_tensor_to_float(input_slice_base, scratch_input, total);
74  layernorm_forward_unrolled_slice(scratch_input, gamma, beta,
75  scratch_output, mean_cache_slice, rstd_cache_slice,
76  num_tokens_in_slice, d_model, eps);
77  float_tensor_to_bf16(scratch_output, output_slice_base, total);
78 }
void layernorm_forward_unrolled_slice(const float *__restrict input_slice_base, const float *__restrict gamma, const float *__restrict beta, float *__restrict output_slice_base, float *__restrict mean_cache_slice, float *__restrict rstd_cache_slice, int num_tokens_in_slice, int d_model, float eps)

References bf16_tensor_to_float(), float_tensor_to_bf16(), and layernorm_forward_unrolled_slice().

◆ layernorm_naive_serial()

void layernorm_naive_serial ( const float *  input,
const float *  gamma,
const float *  beta,
float *  output,
float *  mean_cache,
float *  rstd_cache,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float  eps 
)

Definition at line 51 of file layernorm_kernels.c.

59 {
60  for (int t = 0; t < tokens; ++t) {
61  const float *in_ptr = input + t * aligned_embed_dim;
62  float *out_ptr = output + t * aligned_embed_dim;
63 
64  float sum_val = 0.0f;
65  for (int i = 0; i < d_model; ++i) {
66  sum_val += in_ptr[i];
67  }
68  float mean = sum_val / (float)d_model;
69 
70  float sum_sq_diff = 0.0f;
71  for (int i = 0; i < d_model; ++i) {
72  float diff = in_ptr[i] - mean;
73  sum_sq_diff += diff * diff;
74  }
75  float variance = sum_sq_diff / (float)d_model + eps;
76 
77  double var_double = (double)variance;
78  float inv_std = (float)(1.0 / sqrt(var_double));
79 
80  for (int i = 0; i < d_model; ++i) {
81  float normalized_val = (in_ptr[i] - mean) * inv_std;
82  out_ptr[i] = normalized_val * gamma[i] + beta[i];
83  }
84 
85  if (mean_cache) {
86  mean_cache[t] = mean;
87  }
88  if (rstd_cache) {
89  rstd_cache[t] = inv_std;
90  }
91  /* Keep aligned padding quiet so future GEMMs see deterministic memory. */
92  if (aligned_embed_dim > d_model) {
93  /* Keep padded lanes zeroed so subsequent GEMMs never read stale data. */
94  for (int i = d_model; i < aligned_embed_dim; ++i) {
95  out_ptr[i] = 0.0f;
96  }
97  }
98  }
99 }

Referenced by layernorm_forward_rolled_slice().

◆ layernorm_naive_serial_matched_precision()

void layernorm_naive_serial_matched_precision ( const float *  input,
const float *  gamma,
const float *  beta,
float *  output,
float *  mean_cache,
float *  rstd_cache,
int  tokens,
int  d_model,
float  eps 
)

Definition at line 624 of file layernorm_kernels.c.

631 {
632  for (int t = 0; t < tokens; ++t) {
633  const float *in_ptr = input + t * d_model;
634  float *out_ptr = output + t * d_model;
635 
636  float sum_val = 0.0f;
637  for (int i = 0; i < d_model; ++i) {
638  sum_val += in_ptr[i];
639  }
640  float mean = sum_val / (float)d_model;
641 
642  float sum_sq_diff = 0.0f;
643  for (int i = 0; i < d_model; ++i) {
644  float diff = in_ptr[i] - mean;
645  sum_sq_diff += diff * diff;
646  }
647  float variance = sum_sq_diff / (float)d_model + eps;
648 
649  double var_double = (double)variance;
650  float inv_std = (float)(1.0 / sqrt(var_double));
651 
652  for (int i = 0; i < d_model; ++i) {
653  float normalized_val = (in_ptr[i] - mean) * inv_std;
654  out_ptr[i] = normalized_val * gamma[i] + beta[i];
655  }
656 
657  if (mean_cache) {
658  mean_cache[t] = mean;
659  }
660  if (rstd_cache) {
661  rstd_cache[t] = inv_std;
662  }
663  }
664 }

Referenced by layernorm_forward_unrolled_slice_scalar().

◆ mlp_token_parallel()

void mlp_token_parallel ( const float *  input,
const float *  W_fc1,
const float *  b_fc1,
const float *  W_fc2,
const float *  b_fc2,
float *  fc1_output,
float *  output,
int  T,
int  aligned_dim,
int  num_threads 
)

Definition at line 41 of file mlp_kernels.c.

51 {
52  int D = aligned_dim;
53  int fourD = 4 * D;
54 
55  // FC1: [T × D] · [D × 4D] -> [T × 4D]
56  // Our GEMM layout: A[M×K], B[N×K], so B is [4D × D].
57  gemm_blocked_serial(input, W_fc1, b_fc1,
58  fc1_output,
59  T, // M
60  fourD, // N
61  D); // K
62 
63  // GELU in-place on FC1 output
64  gelu_fast_inplace(fc1_output, (size_t)T * (size_t)fourD);
65 
66  // FC2: [T × 4D] · [4D × D] -> [T × D]
67  gemm_blocked_serial(fc1_output, W_fc2, b_fc2,
68  output,
69  T, // M
70  D, // N
71  fourD); // K
72 }
void gelu_fast_inplace(float *data, size_t n)
Definition: gelu_kernels.c:132
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:661

References gelu_fast_inplace(), and gemm_blocked_serial().

◆ mlp_token_parallel_bf16()

void mlp_token_parallel_bf16 ( const uint16_t *  input,
const uint16_t *  W_fc1,
const uint16_t *  b_fc1,
const uint16_t *  W_fc2,
const uint16_t *  b_fc2,
float *  fc1_output,
float *  output,
int  T,
int  aligned_dim,
int  num_threads,
float *  scratch_bias1_f,
float *  scratch_bias2_f,
uint16_t *  scratch_fc1_bf16 
)

Optimized MLP Forward (BF16 weights, FP32 activations)

Caller-provided scratch buffers: scratch_bias1_f: [4*D] floats scratch_bias2_f: [D] floats scratch_fc1_bf16: [T * 4*D] uint16_t (BF16)

Definition at line 91 of file mlp_kernels_bf16.c.

104 {
105  if (!input || !W_fc1 || !b_fc1 || !W_fc2 || !b_fc2 || !fc1_output || !output) return;
106  if (!scratch_bias1_f || !scratch_bias2_f || !scratch_fc1_bf16) return;
107 
108  (void)num_threads;
109  const int D = aligned_dim;
110  const int fourD = 4 * D;
111 
112  /* Convert biases to FP32 */
113  for (int i = 0; i < fourD; ++i) {
114  scratch_bias1_f[i] = bf16_to_float(b_fc1[i]);
115  }
116  for (int i = 0; i < D; ++i) {
117  scratch_bias2_f[i] = bf16_to_float(b_fc2[i]);
118  }
119 
120  /* FC1: [T, D] x [4D, D].T -> [T, 4D] */
121  gemm_bf16_fp32out(input, W_fc1, scratch_bias1_f, fc1_output, T, fourD, D);
122 
123  /* GELU activation */
124 #if defined(__AVX512F__)
125  #pragma omp parallel for
126  for (int t = 0; t < T; ++t) {
127  float *row = fc1_output + (size_t)t * fourD;
128  int j = 0;
129  for (; j <= fourD - 16; j += 16) {
130  __m512 x = _mm512_loadu_ps(row + j);
131  _mm512_storeu_ps(row + j, gelu_avx512(x));
132  }
133  for (; j < fourD; ++j) {
134  row[j] = gelu_scalar(row[j]);
135  }
136  }
137 #else
138  for (int t = 0; t < T; ++t) {
139  for (int j = 0; j < fourD; ++j) {
140  fc1_output[t * fourD + j] = gelu_scalar(fc1_output[t * fourD + j]);
141  }
142  }
143 #endif
144 
145  /* Convert FP32 activations to BF16 */
146 #if defined(__AVX512F__)
147  #pragma omp parallel for
148  for (int t = 0; t < T; ++t) {
149  float *src = fc1_output + (size_t)t * fourD;
150  uint16_t *dst = scratch_fc1_bf16 + (size_t)t * fourD;
151  int j = 0;
152  for (; j <= fourD - 16; j += 16) {
153  __m512 fp32 = _mm512_loadu_ps(src + j);
154  __m512i as_int = _mm512_castps_si512(fp32);
155  __m512i lsb = _mm512_srli_epi32(as_int, 16);
156  lsb = _mm512_and_si512(lsb, _mm512_set1_epi32(1));
157  __m512i rounding = _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), lsb);
158  __m512i rounded = _mm512_add_epi32(as_int, rounding);
159  __m512i shifted = _mm512_srli_epi32(rounded, 16);
160  __m256i bf16 = _mm512_cvtepi32_epi16(shifted);
161  _mm256_storeu_si256((__m256i *)(dst + j), bf16);
162  }
163  for (; j < fourD; ++j) {
164  dst[j] = float_to_bf16(src[j]);
165  }
166  }
167 #else
168  for (size_t i = 0; i < (size_t)T * fourD; ++i) {
169  scratch_fc1_bf16[i] = float_to_bf16(fc1_output[i]);
170  }
171 #endif
172 
173  /* FC2: BF16 GEMM with FP32 output */
174  gemm_bf16_fp32out(scratch_fc1_bf16, W_fc2, scratch_bias2_f, output, T, D, fourD);
175 }
static float gelu_scalar(float x)
void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B, const float *bias, float *C, int M, int N, int K)

References bf16_to_float(), float_to_bf16(), gelu_scalar(), and gemm_bf16_fp32out().

◆ mlp_token_parallel_bf16_fp32act()

void mlp_token_parallel_bf16_fp32act ( const uint16_t *  input,
const uint16_t *  W_fc1,
const uint16_t *  b_fc1,
const uint16_t *  W_fc2,
const uint16_t *  b_fc2,
float *  fc1_output,
float *  output,
int  T,
int  aligned_dim,
int  num_threads,
float *  scratch_input_f,
float *  scratch_bias1_f,
float *  scratch_bias2_f,
uint16_t *  scratch_fc1_bf16 
)

Alternative: Fully FP32 activations throughout

Caller-provided scratch buffers: scratch_input_f: [T * D] floats scratch_bias1_f: [4*D] floats scratch_bias2_f: [D] floats scratch_fc1_bf16: [T * 4*D] uint16_t (BF16)

Definition at line 186 of file mlp_kernels_bf16.c.

200 {
201  if (!input || !W_fc1 || !b_fc1 || !W_fc2 || !b_fc2 || !fc1_output || !output) return;
202  if (!scratch_input_f || !scratch_bias1_f || !scratch_bias2_f || !scratch_fc1_bf16) return;
203 
204  (void)num_threads;
205  const int D = aligned_dim;
206  const int fourD = 4 * D;
207 
208  /* Convert input and biases to FP32 */
209  bf16_tensor_to_float(input, scratch_input_f, (size_t)T * D);
210  bf16_tensor_to_float(b_fc1, scratch_bias1_f, fourD);
211  bf16_tensor_to_float(b_fc2, scratch_bias2_f, D);
212 
213  /* FC1 */
214  gemm_bf16_fp32out(input, W_fc1, scratch_bias1_f, fc1_output, T, fourD, D);
215 
216  /* GELU */
217 #if defined(__AVX512F__)
218  #pragma omp parallel for
219  for (int t = 0; t < T; ++t) {
220  float *row = fc1_output + (size_t)t * fourD;
221  int j = 0;
222  for (; j <= fourD - 16; j += 16) {
223  __m512 x = _mm512_loadu_ps(row + j);
224  _mm512_storeu_ps(row + j, gelu_avx512(x));
225  }
226  for (; j < fourD; ++j) {
227  row[j] = gelu_scalar(row[j]);
228  }
229  }
230 #else
231  for (size_t i = 0; i < (size_t)T * fourD; ++i) {
232  fc1_output[i] = gelu_scalar(fc1_output[i]);
233  }
234 #endif
235 
236  /* Convert fc1_output to BF16 for FC2 */
237  float_tensor_to_bf16(fc1_output, scratch_fc1_bf16, (size_t)T * fourD);
238  gemm_bf16_fp32out(scratch_fc1_bf16, W_fc2, scratch_bias2_f, output, T, D, fourD);
239 }

References bf16_tensor_to_float(), float_tensor_to_bf16(), gelu_scalar(), and gemm_bf16_fp32out().

◆ mlp_token_parallel_exact()

void mlp_token_parallel_exact ( const float *  input,
const float *  W_fc1,
const float *  b_fc1,
const float *  W_fc2,
const float *  b_fc2,
float *  fc1_output,
float *  output,
int  T,
int  aligned_dim,
int  num_threads 
)

Definition at line 76 of file mlp_kernels.c.

86 {
87  (void)num_threads;
88  int D = aligned_dim;
89  int fourD = 4 * D;
90 
91  // FC1: [T × D] · [D × 4D] -> [T × 4D]
92  gemm_blocked_serial(input, W_fc1, b_fc1,
93  fc1_output,
94  T, // M
95  fourD, // N
96  D); // K
97 
98  // Exact GELU using standard library tanhf
99  gelu_exact_inplace(fc1_output, (size_t)T * (size_t)fourD);
100 
101  // FC2: [T × 4D] · [4D × D] -> [T × D]
102  gemm_blocked_serial(fc1_output, W_fc2, b_fc2,
103  output,
104  T, // M
105  D, // N
106  fourD); // K
107 }

References gelu_exact_inplace(), and gemm_blocked_serial().

◆ moe_accumulate_expert_f32()

void moe_accumulate_expert_f32 ( float *  output,
const float *  expert_output,
float  routing_weight,
int  hidden_dim 
)

Accumulate expert output: output += routing_weight * expert_output.

Parameters
output — Token output buffer [hidden_dim], accumulated in place
expert_output — Expert's output for this token [hidden_dim]
routing_weight — Softmax routing weight for this expert
hidden_dim — Hidden dimension

Definition at line 256 of file axpy_kernels.c.

260 {
261  axpy_f32(output, expert_output, routing_weight, hidden_dim);
262 }

References axpy_f32().

◆ patch2im()

void patch2im ( const float *  d_patches,
float *  d_image,
int  C,
int  H,
int  W,
int  P 
)

patch2im: Accumulates gradients from patches back into the image. (Backward pass)

d_patches: [num_patches, C * P * P] d_image: [C, H, W] (Accumulated)

Definition at line 69 of file vision_kernels.c.

72 {
73  int num_patches_h = H / P;
74  int num_patches_w = W / P;
75  int patch_dim = C * P * P;
76 
77  // Zero out the image first as we are accumulating gradients
78  memset(d_image, 0, (size_t)C * H * W * sizeof(float));
79 
80  for (int ph = 0; ph < num_patches_h; ++ph) {
81  for (int pw = 0; pw < num_patches_w; ++pw) {
82 
83  int patch_idx = ph * num_patches_w + pw;
84  const float *src_patch = d_patches + (size_t)patch_idx * patch_dim;
85 
86  for (int c = 0; c < C; ++c) {
87  for (int py = 0; py < P; ++py) {
88  int y = ph * P + py;
89  int x = pw * P;
90 
91  float *dst_row = d_image + (size_t)c * H * W + (size_t)y * W + x;
92  const float *src_row = src_patch + (size_t)c * P * P + (size_t)py * P;
93 
94  // Add the patch gradient to the image gradient
95  for (int px = 0; px < P; ++px) {
96  dst_row[px] += src_row[px];
97  }
98  }
99  }
100  }
101  }
102 }

References C.

◆ patch2im_bf16()

void patch2im_bf16 ( const uint16_t *  d_patches,
uint16_t *  d_image,
int  C,
int  H,
int  W,
int  P 
)

Definition at line 57 of file vision_kernels_bf16.c.

63 {
64  if (!d_patches || !d_image || C <= 0 || H <= 0 || W <= 0 || P <= 0) {
65  return;
66  }
67 
68  int num_patches_h = H / P;
69  int num_patches_w = W / P;
70  int patch_dim = C * P * P;
71 
72  memset(d_image, 0, (size_t)C * (size_t)H * (size_t)W * sizeof(uint16_t));
73 
74  for (int ph = 0; ph < num_patches_h; ++ph) {
75  for (int pw = 0; pw < num_patches_w; ++pw) {
76  int patch_idx = ph * num_patches_w + pw;
77  const uint16_t *src_patch = d_patches + (size_t)patch_idx * (size_t)patch_dim;
78 
79  for (int c = 0; c < C; ++c) {
80  for (int py = 0; py < P; ++py) {
81  int y = ph * P + py;
82  int x = pw * P;
83 
84  uint16_t *dst_row = d_image + (size_t)c * (size_t)H * (size_t)W + (size_t)y * (size_t)W + (size_t)x;
85  const uint16_t *src_row = src_patch + (size_t)c * (size_t)P * (size_t)P + (size_t)py * (size_t)P;
86 
87  for (int px = 0; px < P; ++px) {
88  float acc = bf16_to_float(dst_row[px]) + bf16_to_float(src_row[px]);
89  dst_row[px] = float_to_bf16(acc);
90  }
91  }
92  }
93  }
94  }
95 }

References bf16_to_float(), C, and float_to_bf16().

◆ quantize_batch_q8_0()

void quantize_batch_q8_0 ( const float *  x,
void *  vy,
int  num_rows,
int  k 
)

Batch quantize FP32 to Q8_0 format (row-major output)

Quantizes multiple rows of FP32 data to Q8_0 format, placing each row's Q8_0 output at the correct byte offset for GEMM compatibility.

Memory layout: Input: [num_rows, k] FP32, row-major (stride = k * sizeof(float)) Output: [num_rows, q8_row_bytes] Q8_0, row-major (stride = q8_row_bytes)

where q8_row_bytes = (k / 32) * sizeof(block_q8_0) = (k / 32) * 34

Parameters
x — Input FP32 values [num_rows * k]
vy — Output Q8_0 blocks [num_rows * (k/32) blocks]
num_rows — Number of rows (batch size / tokens)
k — Elements per row (must be multiple of 32)

Definition at line 192 of file gemm_kernels_q8_0.c.

193 {
194  const size_t row_bytes_in = (size_t)k * sizeof(float);
195  const size_t row_bytes_out = (size_t)(k / QK8_0) * sizeof(block_q8_0);
196 
197  uint8_t *out = (uint8_t *)vy;
198  const uint8_t *in = (const uint8_t *)x;
199 
 200  for (int row = 0; row < num_rows; ++row) {
 201  quantize_row_q8_0(
 202  (const float *)(in + row * row_bytes_in),
203  (void *)(out + row * row_bytes_out),
204  k
205  );
206  }
207 }
void quantize_row_q8_0(const float *x, void *vy, int k)
Quantize FP32 to Q8_0 format (scalar reference)

References QK8_0, and quantize_row_q8_0().

◆ quantize_batch_q8_k()

void quantize_batch_q8_k ( const float *  x,
void *  vy,
int  num_rows,
int  k 
)

Batch quantize FP32 to Q8_K format (row-major output)

Same as quantize_batch_q8_0 but for Q8_K format (super-blocks).

Parameters
x — Input FP32 values [num_rows * k]
vy — Output Q8_K blocks
num_rows — Number of rows (batch size / tokens)
k — Elements per row (must be multiple of 256)

Definition at line 219 of file gemm_kernels_q8_0.c.

220 {
221  /* Q8_K: 256 elements per super-block, each block is larger */
222  const size_t row_bytes_in = (size_t)k * sizeof(float);
 223  /* Q8_K block size = 2 (d) + 256 (qs) + 32 (bsums) = ~290 bytes for 256 elements */
224  /* Actual: sizeof(block_q8_K) from ckernel_quant.h */
225  const size_t row_bytes_out = (size_t)(k / 256) * sizeof(block_q8_K);
226 
227  uint8_t *out = (uint8_t *)vy;
228  const uint8_t *in = (const uint8_t *)x;
229 
 230  for (int row = 0; row < num_rows; ++row) {
 231  quantize_row_q8_k(
 232  (const float *)(in + row * row_bytes_in),
233  (void *)(out + row * row_bytes_out),
234  k
235  );
236  }
237 }
void quantize_row_q8_k(const float *x, void *vy, int k)

References quantize_row_q8_k().

◆ quantize_row_q8_0()

void quantize_row_q8_0 ( const float *  x,
void *  vy,
int  k 
)

Quantize FP32 to Q8_0 format (scalar reference)

Parameters
x — Input FP32 values
vy — Output Q8_0 blocks
k — Number of elements (must be multiple of 32)

Definition at line 59 of file gemm_kernels_q8_0.c.

60 {
61  block_q8_0 *y = (block_q8_0 *)vy;
62  const int nb = k / QK8_0; /* QK8_0 = 32 */
63 
64 #if defined(__AVX__)
65  const __m256 sign_bit = _mm256_set1_ps(-0.0f);
66  const __m256 v_half = _mm256_set1_ps(0.5f);
67  const __m256 v_min = _mm256_set1_ps(-127.0f);
68  const __m256 v_max = _mm256_set1_ps(127.0f);
69 
70  for (int i = 0; i < nb; i++) {
71  __m256 v0 = _mm256_loadu_ps(x + 0);
72  __m256 v1 = _mm256_loadu_ps(x + 8);
73  __m256 v2 = _mm256_loadu_ps(x + 16);
74  __m256 v3 = _mm256_loadu_ps(x + 24);
75  x += QK8_0;
76 
77  __m256 max_abs = _mm256_andnot_ps(sign_bit, v0);
78  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v1));
79  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v2));
80  max_abs = _mm256_max_ps(max_abs, _mm256_andnot_ps(sign_bit, v3));
81 
82  __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(max_abs, 1),
83  _mm256_castps256_ps128(max_abs));
84  max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
85  max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
86  const float max_scalar = _mm_cvtss_f32(max4);
87 
88  const float d = max_scalar / 127.0f;
89  const float id = max_scalar != 0.0f ? 127.0f / max_scalar : 0.0f;
90  y[i].d = CK_FP32_TO_FP16(d);
91 
92  const __m256 mul = _mm256_set1_ps(id);
93  v0 = _mm256_mul_ps(v0, mul);
94  v1 = _mm256_mul_ps(v1, mul);
95  v2 = _mm256_mul_ps(v2, mul);
96  v3 = _mm256_mul_ps(v3, mul);
97 
98  v0 = _mm256_min_ps(_mm256_max_ps(v0, v_min), v_max);
99  v1 = _mm256_min_ps(_mm256_max_ps(v1, v_min), v_max);
100  v2 = _mm256_min_ps(_mm256_max_ps(v2, v_min), v_max);
101  v3 = _mm256_min_ps(_mm256_max_ps(v3, v_min), v_max);
102 
103  /* Round half away from zero to match the scalar path */
104  v0 = _mm256_add_ps(v0, _mm256_or_ps(_mm256_and_ps(v0, sign_bit), v_half));
105  v1 = _mm256_add_ps(v1, _mm256_or_ps(_mm256_and_ps(v1, sign_bit), v_half));
106  v2 = _mm256_add_ps(v2, _mm256_or_ps(_mm256_and_ps(v2, sign_bit), v_half));
107  v3 = _mm256_add_ps(v3, _mm256_or_ps(_mm256_and_ps(v3, sign_bit), v_half));
108 
109  __m256i i0 = _mm256_cvttps_epi32(v0);
110  __m256i i1 = _mm256_cvttps_epi32(v1);
111  __m256i i2 = _mm256_cvttps_epi32(v2);
112  __m256i i3 = _mm256_cvttps_epi32(v3);
113 
114 #if defined(__AVX2__)
115  i0 = _mm256_packs_epi32(i0, i1);
116  i2 = _mm256_packs_epi32(i2, i3);
117  i0 = _mm256_packs_epi16(i0, i2);
118 
119  const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
120  i0 = _mm256_permutevar8x32_epi32(i0, perm);
121  _mm256_storeu_si256((__m256i *)y[i].qs, i0);
122 #else
123  __m128i ni0 = _mm256_castsi256_si128(i0);
124  __m128i ni1 = _mm256_extractf128_si256(i0, 1);
125  __m128i ni2 = _mm256_castsi256_si128(i1);
126  __m128i ni3 = _mm256_extractf128_si256(i1, 1);
127  __m128i ni4 = _mm256_castsi256_si128(i2);
128  __m128i ni5 = _mm256_extractf128_si256(i2, 1);
129  __m128i ni6 = _mm256_castsi256_si128(i3);
130  __m128i ni7 = _mm256_extractf128_si256(i3, 1);
131 
132  ni0 = _mm_packs_epi32(ni0, ni1);
133  ni2 = _mm_packs_epi32(ni2, ni3);
134  ni4 = _mm_packs_epi32(ni4, ni5);
135  ni6 = _mm_packs_epi32(ni6, ni7);
136 
137  ni0 = _mm_packs_epi16(ni0, ni2);
138  ni4 = _mm_packs_epi16(ni4, ni6);
139 
140  _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
141  _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
142 #endif
143  }
144 #else
145  for (int i = 0; i < nb; i++) {
146  const float *xb = x + i * QK8_0;
147 
148  /* Find max absolute value in block */
149  float amax = 0.0f;
150  for (int j = 0; j < QK8_0; j++) {
151  float av = xb[j] >= 0 ? xb[j] : -xb[j];
152  if (av > amax) amax = av;
153  }
154 
155  /* Compute scale: d = max / 127 */
156  float d = amax / 127.0f;
157  float id = d != 0.0f ? 127.0f / amax : 0.0f;
158 
159  /* Store scale as FP16 */
160  y[i].d = CK_FP32_TO_FP16(d);
161 
162  /* Quantize values */
163  for (int j = 0; j < QK8_0; j++) {
164  float v = xb[j] * id;
165  /* Round to nearest int and clamp to [-127, 127] */
166  int q = (int)(v + (v >= 0 ? 0.5f : -0.5f));
167  if (q > 127) q = 127;
168  if (q < -127) q = -127;
169  y[i].qs[j] = (int8_t)q;
170  }
171  }
172 #endif
173 }
#define CK_FP32_TO_FP16(x)
int32_t id
Definition: tokenizer.h:315

References CK_FP32_TO_FP16, block_q8_0::d, id, QK8_0, and block_q8_0::qs.

Referenced by fused_mlp_swiglu_prefill_w1w2_quant(), fused_rmsnorm_qkv_prefill_head_major_quant(), and quantize_attn_out_head_major_q8_0().

◆ quantize_row_q8_k()

void quantize_row_q8_k ( const float *  x,
void *  vy,
int  k 
)

Definition at line 107 of file gemm_kernels_q4k_q8k.c.

107  {
108 #if defined(__SSE4_1__)
109  quantize_row_q8_k_sse(x, vy, k);
110 #else
111  quantize_row_q8_k_ref(x, vy, k);
112 #endif
113 }
void quantize_row_q8_k_sse(const float *x, void *vy, int k)
void quantize_row_q8_k_ref(const float *x, void *vy, int k)

References quantize_row_q8_k_ref(), and quantize_row_q8_k_sse().

Referenced by ck_attention_project_head_major_q4_k_q8_k(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_qkv_project_head_major_q4_k_q8_k(), decode_layer_parallel(), fused_mlp_swiglu_prefill_w1w2_quant(), fused_rmsnorm_qkv_prefill_head_major_quant(), mlp_parallel(), model_decode_token(), model_forward_prefill_impl(), model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_9_decode(), and unfused_rmsnorm_linear_q4k_ref().

◆ relu_backward()

void relu_backward ( const float *  input,
const float *  d_output,
float *  d_input,
size_t  n 
)

Definition at line 84 of file relu_kernels.c.

88 {
89  size_t i = 0;
90 
91 #if defined(__AVX512F__)
92  __m512 vzero = _mm512_setzero_ps();
93  for (; i + 15 < n; i += 16) {
94  __m512 vx = _mm512_loadu_ps(input + i);
95  __m512 vdy = _mm512_loadu_ps(d_output + i);
96  __mmask16 mask = _mm512_cmp_ps_mask(vx, vzero, _CMP_GT_OQ);
97  __m512 vdx = _mm512_maskz_mov_ps(mask, vdy);
98  _mm512_storeu_ps(d_input + i, vdx);
99  }
100 #elif defined(__AVX2__) || defined(__AVX__)
101  __m256 vzero = _mm256_setzero_ps();
102  for (; i + 7 < n; i += 8) {
103  __m256 vx = _mm256_loadu_ps(input + i);
104  __m256 vdy = _mm256_loadu_ps(d_output + i);
105  // Result is all 1s (0xFFFFFFFF) if true, 0 if false.
106  __m256 mask = _mm256_cmp_ps(vx, vzero, _CMP_GT_OQ);
107  __m256 vdx = _mm256_and_ps(mask, vdy);
108  _mm256_storeu_ps(d_input + i, vdx);
109  }
110 #endif
111 
112  // Scalar fallback
113  for (; i < n; ++i) {
114  d_input[i] = (input[i] > 0.0f) ? d_output[i] : 0.0f;
115  }
116 }
int32_t int32_t int32_t int32_t int32_t mask
Definition: tokenizer.h:233

References mask.

◆ relu_backward_bf16()

void relu_backward_bf16 ( const uint16_t *  input,
const uint16_t *  d_output,
uint16_t *  d_input,
size_t  n 
)

Definition at line 45 of file relu_kernels_bf16.c.

49 {
50  if (!input || !d_output || !d_input) {
51  return;
52  }
53  for (size_t i = 0; i < n; ++i) {
54  float x = bf16_to_float(input[i]);
55  float dy = bf16_to_float(d_output[i]);
56  d_input[i] = float_to_bf16(x > 0.0f ? dy : 0.0f);
57  }
58 }

References bf16_to_float(), and float_to_bf16().

◆ relu_forward()

void relu_forward ( const float *  input,
float *  output,
size_t  n 
)

Definition at line 26 of file relu_kernels.c.

27 {
28  size_t i = 0;
29 
30 #if defined(__AVX512F__)
31  __m512 vzero = _mm512_setzero_ps();
32  for (; i + 15 < n; i += 16) {
33  __m512 vx = _mm512_loadu_ps(input + i);
34  __m512 vy = _mm512_max_ps(vx, vzero);
35  _mm512_storeu_ps(output + i, vy);
36  }
37 #elif defined(__AVX2__) || defined(__AVX__)
38  __m256 vzero = _mm256_setzero_ps();
39  for (; i + 7 < n; i += 8) {
40  __m256 vx = _mm256_loadu_ps(input + i);
41  __m256 vy = _mm256_max_ps(vx, vzero);
42  _mm256_storeu_ps(output + i, vy);
43  }
44 #endif
45 
46  // Scalar fallback
47  for (; i < n; ++i) {
48  float x = input[i];
49  output[i] = (x > 0.0f) ? x : 0.0f;
50  }
51 }

◆ relu_forward_bf16()

void relu_forward_bf16 ( const uint16_t *  input,
uint16_t *  output,
size_t  n 
)

Definition at line 23 of file relu_kernels_bf16.c.

24 {
25  if (!input || !output) {
26  return;
27  }
28  for (size_t i = 0; i < n; ++i) {
29  float x = bf16_to_float(input[i]);
30  output[i] = float_to_bf16(x > 0.0f ? x : 0.0f);
31  }
32 }

References bf16_to_float(), and float_to_bf16().

◆ relu_forward_inplace()

void relu_forward_inplace ( float *  data,
size_t  n 
)

Definition at line 54 of file relu_kernels.c.

55 {
56  size_t i = 0;
57 
58 #if defined(__AVX512F__)
59  __m512 vzero = _mm512_setzero_ps();
60  for (; i + 15 < n; i += 16) {
61  __m512 vx = _mm512_loadu_ps(data + i);
62  __m512 vy = _mm512_max_ps(vx, vzero);
63  _mm512_storeu_ps(data + i, vy);
64  }
65 #elif defined(__AVX2__) || defined(__AVX__)
66  __m256 vzero = _mm256_setzero_ps();
67  for (; i + 7 < n; i += 8) {
68  __m256 vx = _mm256_loadu_ps(data + i);
69  __m256 vy = _mm256_max_ps(vx, vzero);
70  _mm256_storeu_ps(data + i, vy);
71  }
72 #endif
73 
74  // Scalar fallback
75  for (; i < n; ++i) {
76  float x = data[i];
77  if (x < 0.0f) {
78  data[i] = 0.0f;
79  }
80  }
81 }

◆ relu_forward_inplace_bf16()

void relu_forward_inplace_bf16 ( uint16_t *  data,
size_t  n 
)

Definition at line 34 of file relu_kernels_bf16.c.

35 {
36  if (!data) {
37  return;
38  }
39  for (size_t i = 0; i < n; ++i) {
40  float x = bf16_to_float(data[i]);
41  data[i] = float_to_bf16(x > 0.0f ? x : 0.0f);
42  }
43 }

References bf16_to_float(), and float_to_bf16().

◆ rmsnorm_backward()

/**
 * RMSNorm backward pass.
 *
 * Given upstream gradient dY, forward input X, scale gamma and the cached
 * reciprocal RMS (rstd = 1/sqrt(mean(x^2)+eps)) per token, computes with
 * x_hat_i = x_i * rstd:
 *   m         = mean_j(dY_j * gamma_j * x_hat_j)
 *   dX_i      = rstd * (dY_i * gamma_i - x_hat_i * m)
 *   dGamma_i += dY_i * x_hat_i   (summed over tokens; zeroed here first)
 *
 * Rows are spaced aligned_embed_dim floats apart; the [d_model, aligned)
 * padding columns of d_input are zeroed.
 *
 * Fix: added the NULL-pointer guard that the bf16/int8/int4 variants of this
 * kernel already perform, for consistent defensive behavior.
 */
void rmsnorm_backward(const float *d_output,
                      const float *input,
                      const float *gamma,
                      const float *rstd_cache,
                      float *d_input,
                      float *d_gamma,
                      int tokens,
                      int d_model,
                      int aligned_embed_dim)
{
    if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) {
        return;
    }

    int T = tokens;
    int D = d_model;
    int aligned = aligned_embed_dim;

    // Zero parameter gradients before accumulation.
#if defined(__AVX512F__)
    {
        int d = 0;
        for (; d + 16 <= D; d += 16) {
            _mm512_storeu_ps(&d_gamma[d], _mm512_setzero_ps());
        }
        for (; d < D; ++d) {
            d_gamma[d] = 0.0f;
        }
    }
#elif defined(__AVX__)
    {
        int d = 0;
        for (; d + 8 <= D; d += 8) {
            _mm256_storeu_ps(&d_gamma[d], _mm256_setzero_ps());
        }
        for (; d < D; ++d) {
            d_gamma[d] = 0.0f;
        }
    }
#else
    for (int d = 0; d < D; ++d) {
        d_gamma[d] = 0.0f;
    }
#endif

    for (int t = 0; t < T; ++t) {
        const float *x = input + (size_t)t * aligned;
        const float *dY = d_output + (size_t)t * aligned;
        float *dX = d_input + (size_t)t * aligned;

        float rstd = rstd_cache[t];

#if defined(__AVX512F__)
        // Pass 1: m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j)
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        __m512 sum_vec = _mm512_setzero_ps();
        int d = 0;

        for (; d + 16 <= D; d += 16) {
            __m512 xv = _mm512_loadu_ps(&x[d]);
            __m512 dyv = _mm512_loadu_ps(&dY[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            __m512 prod = _mm512_mul_ps(dyv, gv);
            sum_vec = _mm512_fmadd_ps(prod, x_hat, sum_vec);
        }
        float sum_dY_g_xhat = _mm512_reduce_add_ps(sum_vec);

        for (; d < D; ++d) {
            float x_hat = x[d] * rstd;
            sum_dY_g_xhat += dY[d] * gamma[d] * x_hat;
        }
        float m = sum_dY_g_xhat / (float)D;

        // Pass 2: write dX and accumulate dGamma.
        __m512 m_vec = _mm512_set1_ps(m);
        d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = _mm512_loadu_ps(&x[d]);
            __m512 dyv = _mm512_loadu_ps(&dY[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 dgv = _mm512_loadu_ps(&d_gamma[d]);

            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);

            // dX = rstd * (dY * gamma - x_hat * m)
            __m512 dy_g = _mm512_mul_ps(dyv, gv);
            __m512 xhat_m = _mm512_mul_ps(x_hat, m_vec);
            __m512 diff = _mm512_sub_ps(dy_g, xhat_m);
            __m512 dxv = _mm512_mul_ps(rstd_vec, diff);
            _mm512_storeu_ps(&dX[d], dxv);

            // d_gamma += dY * x_hat
            dgv = _mm512_fmadd_ps(dyv, x_hat, dgv);
            _mm512_storeu_ps(&d_gamma[d], dgv);
        }
        for (; d < D; ++d) {
            float x_hat = x[d] * rstd;
            float dy = dY[d];
            dX[d] = rstd * (dy * gamma[d] - x_hat * m);
            d_gamma[d] += dy * x_hat;
        }

#elif defined(__AVX__)
        // Pass 1: m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j)
        __m256 rstd_vec = _mm256_set1_ps(rstd);
        __m256 sum_vec = _mm256_setzero_ps();
        int d = 0;

        for (; d + 8 <= D; d += 8) {
            __m256 xv = _mm256_loadu_ps(&x[d]);
            __m256 dyv = _mm256_loadu_ps(&dY[d]);
            __m256 gv = _mm256_loadu_ps(&gamma[d]);
            __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
            // No FMA in AVX1: mul + mul + add.
            __m256 prod = _mm256_mul_ps(dyv, gv);
            __m256 prod2 = _mm256_mul_ps(prod, x_hat);
            sum_vec = _mm256_add_ps(sum_vec, prod2);
        }
        float sum_dY_g_xhat = hsum256_ps_rmsnorm(sum_vec);

        for (; d < D; ++d) {
            float x_hat = x[d] * rstd;
            sum_dY_g_xhat += dY[d] * gamma[d] * x_hat;
        }
        float m = sum_dY_g_xhat / (float)D;

        // Pass 2: write dX and accumulate dGamma.
        __m256 m_vec = _mm256_set1_ps(m);
        d = 0;
        for (; d + 8 <= D; d += 8) {
            __m256 xv = _mm256_loadu_ps(&x[d]);
            __m256 dyv = _mm256_loadu_ps(&dY[d]);
            __m256 gv = _mm256_loadu_ps(&gamma[d]);
            __m256 dgv = _mm256_loadu_ps(&d_gamma[d]);

            __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);

            // dX = rstd * (dY * gamma - x_hat * m)
            __m256 dy_g = _mm256_mul_ps(dyv, gv);
            __m256 xhat_m = _mm256_mul_ps(x_hat, m_vec);
            __m256 diff = _mm256_sub_ps(dy_g, xhat_m);
            __m256 dxv = _mm256_mul_ps(rstd_vec, diff);
            _mm256_storeu_ps(&dX[d], dxv);

            // d_gamma += dY * x_hat
            __m256 dy_xhat = _mm256_mul_ps(dyv, x_hat);
            dgv = _mm256_add_ps(dgv, dy_xhat);
            _mm256_storeu_ps(&d_gamma[d], dgv);
        }
        for (; d < D; ++d) {
            float x_hat = x[d] * rstd;
            float dy = dY[d];
            dX[d] = rstd * (dy * gamma[d] - x_hat * m);
            d_gamma[d] += dy * x_hat;
        }

#else
        // Scalar fallback: accumulate the mean term in double for accuracy.
        double sum_dY_g_xhat = 0.0;
        for (int d = 0; d < D; ++d) {
            float x_hat = x[d] * rstd;
            sum_dY_g_xhat += (double)dY[d] * (double)gamma[d] * (double)x_hat;
        }
        float m = (float)(sum_dY_g_xhat / (double)D);

        for (int d = 0; d < D; ++d) {
            float x_hat = x[d] * rstd;
            float dy = dY[d];
            dX[d] = rstd * (dy * gamma[d] - x_hat * m);
            d_gamma[d] += dy * x_hat;
        }
#endif

        // Zero padding gradients (if any).
        for (int d = D; d < aligned; ++d) {
            dX[d] = 0.0f;
        }
    }
}

Referenced by ck_layer_backward_rmsnorm_swiglu(), rmsnorm_backward_int4(), and rmsnorm_backward_int8().

◆ rmsnorm_backward_bf16()

/*
 * RMSNorm backward pass on bf16 activations (fp32 gamma and gamma-gradient).
 *
 * With x_hat = x * rstd (rstd cached per token by the forward pass):
 *   m         = mean_j(dY_j * gamma_j * x_hat_j)
 *   dX        = rstd * (dY * gamma - x_hat * m)      (stored as bf16)
 *   dGamma_i += dY_i * x_hat_i                        (fp32, zeroed here first)
 *
 * Rows are aligned_embed_dim elements apart; padding columns of d_input are
 * zeroed. bf16_loadu_cvt_fp32 / fp32_cvt_storeu_bf16 are the project's
 * AVX-512 bf16<->fp32 helpers; bf16_to_float / float_to_bf16 the scalar ones.
 */
void rmsnorm_backward_bf16(const uint16_t *d_output,
                           const uint16_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           uint16_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim)
{
    int T = tokens;
    int D = d_model;
    int aligned = aligned_embed_dim;

    if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) {
        return;
    }

    // Zero parameter gradients before accumulation.
#if defined(__AVX512F__)
    {
        int d = 0;
        for (; d + 16 <= D; d += 16) {
            _mm512_storeu_ps(&d_gamma[d], _mm512_setzero_ps());
        }
        for (; d < D; ++d) {
            d_gamma[d] = 0.0f;
        }
    }
#else
    for (int d = 0; d < D; ++d) {
        d_gamma[d] = 0.0f;
    }
#endif

    for (int t = 0; t < T; ++t) {
        const uint16_t *x_bf16 = input + (size_t)t * aligned;
        const uint16_t *dY_bf16 = d_output + (size_t)t * aligned;
        uint16_t *dX_bf16 = d_input + (size_t)t * aligned;
        float rstd = rstd_cache[t];

#if defined(__AVX512F__)
        // Pass 1: m = (1/D) * sum_j (dY_j * gamma_j * x_hat_j)
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        __m512 sum_vec = _mm512_setzero_ps();
        int d = 0;

        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 dyv = bf16_loadu_cvt_fp32(&dY_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            // sum += dY * gamma * x_hat
            __m512 prod = _mm512_mul_ps(dyv, gv);
            sum_vec = _mm512_fmadd_ps(prod, x_hat, sum_vec);
        }
        float sum_dY_g_xhat = _mm512_reduce_add_ps(sum_vec);

        // Scalar remainder of pass 1.
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            sum_dY_g_xhat += dy * gamma[d] * x_hat;
        }
        float m = sum_dY_g_xhat / (float)D;

        // Pass 2: write dX (bf16) and accumulate dGamma (fp32).
        __m512 m_vec = _mm512_set1_ps(m);
        d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 dyv = bf16_loadu_cvt_fp32(&dY_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 dgv = _mm512_loadu_ps(&d_gamma[d]);

            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);

            // dX = rstd * (dY * gamma - x_hat * m)
            __m512 dy_g = _mm512_mul_ps(dyv, gv);
            __m512 xhat_m = _mm512_mul_ps(x_hat, m_vec);
            __m512 diff = _mm512_sub_ps(dy_g, xhat_m);
            __m512 dxv = _mm512_mul_ps(rstd_vec, diff);
            fp32_cvt_storeu_bf16(&dX_bf16[d], dxv);

            // d_gamma += dY * x_hat
            dgv = _mm512_fmadd_ps(dyv, x_hat, dgv);
            _mm512_storeu_ps(&d_gamma[d], dgv);
        }
        // Scalar remainder of pass 2.
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            float dx = rstd * (dy * gamma[d] - x_hat * m);
            dX_bf16[d] = float_to_bf16(dx);
            d_gamma[d] += dy * x_hat;
        }

#else
        // Scalar fallback: accumulate the mean term in double for accuracy.
        double sum_dY_g_xhat = 0.0;
        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            sum_dY_g_xhat += (double)dy * (double)gamma[d] * (double)x_hat;
        }
        float m = (float)(sum_dY_g_xhat / (double)D);

        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float dy = bf16_to_float(dY_bf16[d]);
            float dx = rstd * (dy * gamma[d] - x_hat * m);
            dX_bf16[d] = float_to_bf16(dx);
            d_gamma[d] += dy * x_hat;
        }
#endif

        // Zero padding gradients.
        for (int d = D; d < aligned; ++d) {
            dX_bf16[d] = 0;
        }
    }
}

References bf16_to_float(), and float_to_bf16().

◆ rmsnorm_backward_int4()

/*
 * RMSNorm backward for int4-packed activations.
 *
 * Dequantizes d_output and input into fp32 scratch buffers, runs the fp32
 * backward kernel, then re-quantizes d_input. d_gamma stays fp32.
 * Each scratch buffer must hold tokens * aligned_embed_dim floats.
 *
 * Fix: removed the redundant d_gamma-clearing loop — rmsnorm_backward zeroes
 * d_gamma itself before accumulating, so the extra pass did duplicate work.
 */
void rmsnorm_backward_int4(const uint8_t *d_output,
                           const uint8_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           uint8_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim,
                           float *scratch_d_output,
                           float *scratch_input,
                           float *scratch_d_input)
{
    if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) return;
    if (!scratch_d_output || !scratch_input || !scratch_d_input) return;

    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;

    convert_int4_to_float(d_output, scratch_d_output, total);
    convert_int4_to_float(input, scratch_input, total);

    /* rmsnorm_backward zeroes d_gamma before accumulation. */
    rmsnorm_backward(scratch_d_output,
                     scratch_input,
                     gamma,
                     rstd_cache,
                     scratch_d_input,
                     d_gamma,
                     tokens,
                     d_model,
                     aligned_embed_dim);

    convert_float_to_int4(scratch_d_input, d_input, total);
}
void rmsnorm_backward(const float *d_output, const float *input, const float *gamma, const float *rstd_cache, float *d_input, float *d_gamma, int tokens, int d_model, int aligned_embed_dim)
static void convert_int4_to_float(const uint8_t *src, float *dst, size_t count)
static void convert_float_to_int4(const float *src, uint8_t *dst, size_t count)

References convert_float_to_int4(), convert_int4_to_float(), and rmsnorm_backward().

◆ rmsnorm_backward_int8()

/*
 * RMSNorm backward for int8-quantized activations.
 *
 * Dequantizes d_output and input into fp32 scratch buffers, runs the fp32
 * backward kernel, then re-quantizes d_input. d_gamma stays fp32.
 * Each scratch buffer must hold tokens * aligned_embed_dim floats.
 *
 * Fix: removed the redundant d_gamma-clearing loop — rmsnorm_backward zeroes
 * d_gamma itself before accumulating, so the extra pass did duplicate work.
 */
void rmsnorm_backward_int8(const int8_t *d_output,
                           const int8_t *input,
                           const float *gamma,
                           const float *rstd_cache,
                           int8_t *d_input,
                           float *d_gamma,
                           int tokens,
                           int d_model,
                           int aligned_embed_dim,
                           float *scratch_d_output,
                           float *scratch_input,
                           float *scratch_d_input)
{
    if (!d_output || !input || !gamma || !rstd_cache || !d_input || !d_gamma) return;
    if (!scratch_d_output || !scratch_input || !scratch_d_input) return;

    size_t total = (size_t)tokens * (size_t)aligned_embed_dim;

    convert_int8_to_float(d_output, scratch_d_output, total);
    convert_int8_to_float(input, scratch_input, total);

    /* rmsnorm_backward zeroes d_gamma before accumulation. */
    rmsnorm_backward(scratch_d_output,
                     scratch_input,
                     gamma,
                     rstd_cache,
                     scratch_d_input,
                     d_gamma,
                     tokens,
                     d_model,
                     aligned_embed_dim);

    convert_float_to_int8(scratch_d_input, d_input, total);
}
static void convert_int8_to_float(const int8_t *src, float *dst, size_t count)
static void convert_float_to_int8(const float *src, int8_t *dst, size_t count)

References convert_float_to_int8(), convert_int8_to_float(), and rmsnorm_backward().

◆ rmsnorm_forward()

void rmsnorm_forward ( const float *  input,
const float *  gamma,
float *  output,
float *  rstd_cache,
int  tokens,
int  d_model,
int  aligned_embed_dim,
float  eps 
)

RMSNorm forward pass

Test:

test_rmsnorm.py::TestRMSNormForward::test_fp32_tokens

test_rmsnorm.py::TestRMSNormForward::test_fp32_single

test_rmsnorm.py::TestRMSNormForward::test_perf_rolled

test_layernorm.py::TestLayerNormForward::test_rmsnorm_compat

test_parity.py::test_rmsnorm_parity

RMSNorm: y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps)

After changes: make test && make llamacpp-parity-full

Definition at line 50 of file rmsnorm_kernels.c.

58 {
59  int T = tokens;
60  int D = d_model;
61  int aligned = aligned_embed_dim;
62 
63  for (int t = 0; t < T; ++t) {
64  const float *x = input + (size_t)t * aligned;
65  float *y = output + (size_t)t * aligned;
66 
67 #if defined(__AVX512F__)
68  // AVX-512: Process 16 floats at a time
69  __m512 sum_sq_vec = _mm512_setzero_ps();
70  int d = 0;
71 
72  // Vectorized sum of squares
73  for (; d + 16 <= D; d += 16) {
74  __m512 xv = _mm512_loadu_ps(&x[d]);
75  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
76  }
77  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
78 
79  // Handle remaining elements
80  for (; d < D; ++d) {
81  sum_sq += x[d] * x[d];
82  }
83 
84  float mean_sq = sum_sq / (float)D;
85  float rstd = 1.0f / sqrtf(mean_sq + eps);
86  if (rstd_cache) {
87  rstd_cache[t] = rstd;
88  }
89 
90  // Apply normalization and scale (vectorized)
91  __m512 rstd_vec = _mm512_set1_ps(rstd);
92  d = 0;
93  for (; d + 16 <= D; d += 16) {
94  __m512 xv = _mm512_loadu_ps(&x[d]);
95  __m512 gv = _mm512_loadu_ps(&gamma[d]);
96  __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
97  __m512 yv = _mm512_mul_ps(x_hat, gv);
98  _mm512_storeu_ps(&y[d], yv);
99  }
100  // Handle remaining elements
101  for (; d < D; ++d) {
102  y[d] = x[d] * rstd * gamma[d];
103  }
104 
105 #elif defined(__AVX__)
106  // AVX: Process 8 floats at a time
107  __m256 sum_sq_vec = _mm256_setzero_ps();
108  int d = 0;
109 
110  // Vectorized sum of squares (no FMA in AVX1, use mul + add)
111  for (; d + 8 <= D; d += 8) {
112  __m256 xv = _mm256_loadu_ps(&x[d]);
113  __m256 xv_sq = _mm256_mul_ps(xv, xv);
114  sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
115  }
116  float sum_sq = hsum256_ps_rmsnorm(sum_sq_vec);
117 
118  // Handle remaining elements
119  for (; d < D; ++d) {
120  sum_sq += x[d] * x[d];
121  }
122 
123  float mean_sq = sum_sq / (float)D;
124  float rstd = 1.0f / sqrtf(mean_sq + eps);
125  if (rstd_cache) {
126  rstd_cache[t] = rstd;
127  }
128 
129  // Apply normalization and scale (vectorized)
130  __m256 rstd_vec = _mm256_set1_ps(rstd);
131  d = 0;
132  for (; d + 8 <= D; d += 8) {
133  __m256 xv = _mm256_loadu_ps(&x[d]);
134  __m256 gv = _mm256_loadu_ps(&gamma[d]);
135  __m256 x_hat = _mm256_mul_ps(xv, rstd_vec);
136  __m256 yv = _mm256_mul_ps(x_hat, gv);
137  _mm256_storeu_ps(&y[d], yv);
138  }
139  // Handle remaining elements
140  for (; d < D; ++d) {
141  y[d] = x[d] * rstd * gamma[d];
142  }
143 
144 #else
145  // Scalar fallback
146  double sum_sq = 0.0;
147  for (int d = 0; d < D; ++d) {
148  double v = (double)x[d];
149  sum_sq += v * v;
150  }
151  double mean_sq = sum_sq / (double)D;
152  double r = sqrt(mean_sq + (double)eps);
153  float rstd = (float)(1.0 / r);
154  if (rstd_cache) {
155  rstd_cache[t] = rstd;
156  }
157 
158  // Apply normalization and scale
159  for (int d = 0; d < D; ++d) {
160  float x_hat = x[d] * rstd;
161  y[d] = x_hat * gamma[d];
162  }
163 #endif
164 
165  // Zero padding (if any)
166  for (int d = D; d < aligned; ++d) {
167  y[d] = 0.0f;
168  }
169  }
170 }

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), mega_fused_outproj_mlp_prefill(), model_decode_token(), model_forward_prefill_impl(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_forward_prefill_impl(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), 
qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), qwen2_0_5b_decode_layer_9_prefill(), rmsnorm_forward_int4(), and rmsnorm_forward_int8().

◆ rmsnorm_forward_bf16()

/*
 * RMSNorm forward on bf16 activations with fp32 gamma:
 *   y[i] = gamma[i] * x[i] / sqrt(mean(x^2) + eps), stored as bf16.
 *
 * Rows are aligned_embed_dim elements apart; padding columns of the output
 * are zeroed. If rstd_cache is non-NULL, the per-token reciprocal RMS is
 * written for the backward pass. bf16_loadu_cvt_fp32 / fp32_cvt_storeu_bf16
 * are the project's AVX-512 bf16<->fp32 helpers.
 */
void rmsnorm_forward_bf16(const uint16_t *input,
                          const float *gamma,
                          uint16_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps)
{
    int T = tokens;
    int D = d_model;
    int aligned = aligned_embed_dim;

    for (int t = 0; t < T; ++t) {
        const uint16_t *x_bf16 = input + (size_t)t * aligned;
        float *rstd_ptr = rstd_cache ? (rstd_cache + t) : NULL;
        uint16_t *out_bf16 = output + (size_t)t * aligned;

#if defined(__AVX512F__)
        // AVX-512: 16 elements per iteration; math in fp32.
        __m512 sum_sq_vec = _mm512_setzero_ps();
        int d = 0;

        // Vectorized sum of squares.
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
        }
        float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);

        // Scalar remainder.
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            sum_sq += x * x;
        }

        float mean_sq = sum_sq / (float)D;
        float rstd = 1.0f / sqrtf(mean_sq + eps);
        if (rstd_ptr) {
            *rstd_ptr = rstd;
        }

        // Normalize and scale, storing bf16.
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        d = 0;
        for (; d + 16 <= D; d += 16) {
            __m512 xv = bf16_loadu_cvt_fp32(&x_bf16[d]);
            __m512 gv = _mm512_loadu_ps(&gamma[d]);
            __m512 x_hat = _mm512_mul_ps(xv, rstd_vec);
            __m512 yv = _mm512_mul_ps(x_hat, gv);
            fp32_cvt_storeu_bf16(&out_bf16[d], yv);
        }
        // Scalar remainder.
        for (; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float y = x * rstd * gamma[d];
            out_bf16[d] = float_to_bf16(y);
        }

#else
        // Scalar fallback with double accumulation for accuracy.
        double sum_sq = 0.0;
        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            sum_sq += (double)x * (double)x;
        }
        double mean_sq = sum_sq / (double)D;
        double r = sqrt(mean_sq + (double)eps);
        float rstd = (float)(1.0 / r);
        if (rstd_ptr) {
            *rstd_ptr = rstd;
        }

        for (int d = 0; d < D; ++d) {
            float x = bf16_to_float(x_bf16[d]);
            float x_hat = x * rstd;
            float y = x_hat * gamma[d];
            out_bf16[d] = float_to_bf16(y);
        }
#endif

        // Zero bf16 padding columns.
        for (int d = D; d < aligned; ++d) {
            out_bf16[d] = 0;
        }
    }
}

References bf16_to_float(), and float_to_bf16().

◆ rmsnorm_forward_int4()

/*
 * RMSNorm forward for int4-packed activations: dequantize to fp32 scratch,
 * normalize with the fp32 kernel, re-quantize to int4. Scratch buffers hold
 * tokens * aligned_embed_dim floats each; rstd_cache may be NULL.
 */
void rmsnorm_forward_int4(const uint8_t *input,
                          const float *gamma,
                          uint8_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps,
                          float *scratch_input,
                          float *scratch_output)
{
    if (input == NULL || gamma == NULL || output == NULL) {
        return;
    }
    if (scratch_input == NULL || scratch_output == NULL) {
        return;
    }

    const size_t count = (size_t)tokens * (size_t)aligned_embed_dim;

    convert_int4_to_float(input, scratch_input, count);
    rmsnorm_forward(scratch_input, gamma, scratch_output, rstd_cache,
                    tokens, d_model, aligned_embed_dim, eps);
    convert_float_to_int4(scratch_output, output, count);
}
void rmsnorm_forward(const float *input, const float *gamma, float *output, float *rstd_cache, int tokens, int d_model, int aligned_embed_dim, float eps)

References convert_float_to_int4(), convert_int4_to_float(), and rmsnorm_forward().

◆ rmsnorm_forward_int8()

/*
 * RMSNorm forward for int8-quantized activations: dequantize to fp32 scratch,
 * normalize with the fp32 kernel, re-quantize to int8. Scratch buffers hold
 * tokens * aligned_embed_dim floats each; rstd_cache may be NULL.
 */
void rmsnorm_forward_int8(const int8_t *input,
                          const float *gamma,
                          int8_t *output,
                          float *rstd_cache,
                          int tokens,
                          int d_model,
                          int aligned_embed_dim,
                          float eps,
                          float *scratch_input,
                          float *scratch_output)
{
    if (input == NULL || gamma == NULL || output == NULL) {
        return;
    }
    if (scratch_input == NULL || scratch_output == NULL) {
        return;
    }

    const size_t count = (size_t)tokens * (size_t)aligned_embed_dim;

    convert_int8_to_float(input, scratch_input, count);
    rmsnorm_forward(scratch_input, gamma, scratch_output, rstd_cache,
                    tokens, d_model, aligned_embed_dim, eps);
    convert_float_to_int8(scratch_output, output, count);
}

References convert_float_to_int8(), convert_int8_to_float(), and rmsnorm_forward().

◆ rope_backward()

/**
 * RoPE backward: inverse rotation (rotate by -theta).
 *
 * Since cos(-t)=cos(t) and sin(-t)=-sin(t), per (d0, d1) pair split at
 * half_dim:
 *   d_x[i]            =  d0 * c + d1 * s
 *   d_x[i + half_dim] = -d0 * s + d1 * c
 *
 * Layout is head-major: each head is num_tokens rows of aligned_head_dim
 * floats; cos_cache/sin_cache hold half_dim values per absolute position
 * (pos_offset + t). Padding columns [head_dim, aligned_head_dim) of d_x are
 * zeroed.
 *
 * Fix: cache-row indexing `pos * half_dim` was computed in int arithmetic
 * and could overflow for very long contexts; it is now done in size_t,
 * matching the size_t indexing used elsewhere in this kernel.
 */
void rope_backward(const float *d_out,
                   float *d_x,
                   const float *cos_cache,
                   const float *sin_cache,
                   int num_heads,
                   int num_tokens,
                   int head_dim,
                   int aligned_head_dim,
                   int pos_offset)
{
    size_t head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    int half_dim = head_dim / 2;

    for (int h = 0; h < num_heads; ++h) {
        for (int t = 0; t < num_tokens; ++t) {
            int pos = pos_offset + t;
            const float *cos_row = cos_cache + (size_t)pos * (size_t)half_dim;
            const float *sin_row = sin_cache + (size_t)pos * (size_t)half_dim;

            size_t idx = (size_t)h * head_stride + (size_t)t * (size_t)aligned_head_dim;
            const float *d_out_row = d_out + idx;
            float *d_x_row = d_x + idx;

#if defined(__AVX512F__)
            int i = 0;
            for (; i + 16 <= half_dim; i += 16) {
                __m512 d0 = _mm512_loadu_ps(&d_out_row[i]);
                __m512 d1 = _mm512_loadu_ps(&d_out_row[i + half_dim]);
                __m512 c = _mm512_loadu_ps(&cos_row[i]);
                __m512 s = _mm512_loadu_ps(&sin_row[i]);

                // Inverse: d_x[i] = d0 * c + d1 * s
                __m512 r0 = _mm512_fmadd_ps(d0, c, _mm512_mul_ps(d1, s));
                // Inverse: d_x[i+half] = -d0 * s + d1 * c
                __m512 r1 = _mm512_fmsub_ps(d1, c, _mm512_mul_ps(d0, s));

                _mm512_storeu_ps(&d_x_row[i], r0);
                _mm512_storeu_ps(&d_x_row[i + half_dim], r1);
            }
            for (; i < half_dim; ++i) {
                float d0 = d_out_row[i];
                float d1 = d_out_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_x_row[i] = d0 * c + d1 * s;
                d_x_row[i + half_dim] = -d0 * s + d1 * c;
            }

#elif defined(__AVX__)
            int i = 0;
            for (; i + 8 <= half_dim; i += 8) {
                __m256 d0 = _mm256_loadu_ps(&d_out_row[i]);
                __m256 d1 = _mm256_loadu_ps(&d_out_row[i + half_dim]);
                __m256 c = _mm256_loadu_ps(&cos_row[i]);
                __m256 s = _mm256_loadu_ps(&sin_row[i]);

                // Inverse: d_x[i] = d0 * c + d1 * s
                __m256 d0c = _mm256_mul_ps(d0, c);
                __m256 d1s = _mm256_mul_ps(d1, s);
                __m256 r0 = _mm256_add_ps(d0c, d1s);

                // Inverse: d_x[i+half] = -d0 * s + d1 * c = d1 * c - d0 * s
                __m256 d1c = _mm256_mul_ps(d1, c);
                __m256 d0s = _mm256_mul_ps(d0, s);
                __m256 r1 = _mm256_sub_ps(d1c, d0s);

                _mm256_storeu_ps(&d_x_row[i], r0);
                _mm256_storeu_ps(&d_x_row[i + half_dim], r1);
            }
            for (; i < half_dim; ++i) {
                float d0 = d_out_row[i];
                float d1 = d_out_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_x_row[i] = d0 * c + d1 * s;
                d_x_row[i + half_dim] = -d0 * s + d1 * c;
            }

#else
            // Scalar fallback: inverse rotation, rotate by -theta.
            for (int i = 0; i < half_dim; ++i) {
                float d0 = d_out_row[i];
                float d1 = d_out_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];

                d_x_row[i] = d0 * c + d1 * s;
                d_x_row[i + half_dim] = -d0 * s + d1 * c;
            }
#endif

            // Zero padding columns.
            for (int i = head_dim; i < aligned_head_dim; ++i) {
                d_x_row[i] = 0.0f;
            }
        }
    }
}

Referenced by rope_backward_bf16(), and rope_backward_qk().

◆ rope_backward_bf16()

/*
 * RoPE backward for bf16 gradients: widen to fp32 scratch, apply the fp32
 * inverse rotation, narrow back to bf16. Each scratch buffer must hold
 * num_heads * num_tokens * aligned_head_dim floats.
 */
void rope_backward_bf16(const uint16_t *d_out,
                        uint16_t *d_x,
                        const float *cos_cache,
                        const float *sin_cache,
                        int num_heads,
                        int num_tokens,
                        int head_dim,
                        int aligned_head_dim,
                        int pos_offset,
                        float *scratch_d_out,
                        float *scratch_d_x)
{
    if (scratch_d_out == NULL || scratch_d_x == NULL) {
        return;
    }

    const size_t count =
        (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;

    bf16_tensor_to_float(d_out, scratch_d_out, count);
    rope_backward(scratch_d_out, scratch_d_x, cos_cache, sin_cache,
                  num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
    float_tensor_to_bf16(scratch_d_x, d_x, count);
}
void rope_backward(const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:238

References bf16_tensor_to_float(), float_tensor_to_bf16(), and rope_backward().

Referenced by rope_backward_qk_bf16().

◆ rope_backward_inplace()

/*
 * RoPE backward, in place: overwrite d_x with the inverse-rotated gradients
 * (rotate by -theta). Saves a buffer when d_x == d_out is acceptable.
 *
 * Per (d0, d1) pair split at half_dim:
 *   d_x[i]            =  d0 * c + d1 * s
 *   d_x[i + half_dim] = -d0 * s + d1 * c
 *
 * Head-major layout: each head is num_tokens rows of aligned_head_dim floats;
 * cos_cache/sin_cache hold half_dim values per absolute position
 * (pos_offset + t). Padding columns [head_dim, aligned_head_dim) are zeroed.
 *
 * NOTE(review): `pos * half_dim` below is int arithmetic (same pattern as
 * rope_backward) — could overflow for extremely long contexts; confirm limits.
 */
void rope_backward_inplace(float *d_x,
                           const float *cos_cache,
                           const float *sin_cache,
                           int num_heads,
                           int num_tokens,
                           int head_dim,
                           int aligned_head_dim,
                           int pos_offset)
{
    size_t head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
    int half_dim = head_dim / 2;

    for (int h = 0; h < num_heads; ++h) {
        for (int t = 0; t < num_tokens; ++t) {
            int pos = pos_offset + t;
            const float *cos_row = cos_cache + pos * half_dim;
            const float *sin_row = sin_cache + pos * half_dim;

            float *d_row = d_x + h * head_stride + (size_t)t * (size_t)aligned_head_dim;

#if defined(__AVX512F__)
            int i = 0;
            for (; i + 16 <= half_dim; i += 16) {
                __m512 d0 = _mm512_loadu_ps(&d_row[i]);
                __m512 d1 = _mm512_loadu_ps(&d_row[i + half_dim]);
                __m512 c = _mm512_loadu_ps(&cos_row[i]);
                __m512 s = _mm512_loadu_ps(&sin_row[i]);

                // r0 = d0*c + d1*s ; r1 = d1*c - d0*s
                __m512 r0 = _mm512_fmadd_ps(d0, c, _mm512_mul_ps(d1, s));
                __m512 r1 = _mm512_fmsub_ps(d1, c, _mm512_mul_ps(d0, s));

                _mm512_storeu_ps(&d_row[i], r0);
                _mm512_storeu_ps(&d_row[i + half_dim], r1);
            }
            // Scalar remainder.
            for (; i < half_dim; ++i) {
                float d0 = d_row[i];
                float d1 = d_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_row[i] = d0 * c + d1 * s;
                d_row[i + half_dim] = -d0 * s + d1 * c;
            }

#elif defined(__AVX__)
            int i = 0;
            for (; i + 8 <= half_dim; i += 8) {
                __m256 d0 = _mm256_loadu_ps(&d_row[i]);
                __m256 d1 = _mm256_loadu_ps(&d_row[i + half_dim]);
                __m256 c = _mm256_loadu_ps(&cos_row[i]);
                __m256 s = _mm256_loadu_ps(&sin_row[i]);

                // No FMA in AVX1: compose from mul/add/sub.
                __m256 d0c = _mm256_mul_ps(d0, c);
                __m256 d1s = _mm256_mul_ps(d1, s);
                __m256 r0 = _mm256_add_ps(d0c, d1s);

                __m256 d1c = _mm256_mul_ps(d1, c);
                __m256 d0s = _mm256_mul_ps(d0, s);
                __m256 r1 = _mm256_sub_ps(d1c, d0s);

                _mm256_storeu_ps(&d_row[i], r0);
                _mm256_storeu_ps(&d_row[i + half_dim], r1);
            }
            // Scalar remainder.
            for (; i < half_dim; ++i) {
                float d0 = d_row[i];
                float d1 = d_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];
                d_row[i] = d0 * c + d1 * s;
                d_row[i + half_dim] = -d0 * s + d1 * c;
            }

#else
            for (int i = 0; i < half_dim; ++i) {
                float d0 = d_row[i];
                float d1 = d_row[i + half_dim];
                float c = cos_row[i];
                float s = sin_row[i];

                // Inverse rotation: rotate by -theta
                d_row[i] = d0 * c + d1 * s;
                d_row[i + half_dim] = -d0 * s + d1 * c;
            }
#endif

            // Zero padding columns.
            for (int i = head_dim; i < aligned_head_dim; ++i) {
                d_row[i] = 0.0f;
            }
        }
    }
}

◆ rope_backward_qk()

/*
 * Combined RoPE backward for the Q and K gradient streams.
 * Q uses num_heads; K uses the (possibly smaller, GQA) num_kv_heads.
 * Both apply the same inverse rotation via rope_backward.
 */
void rope_backward_qk(const float *d_q_out,
                      const float *d_k_out,
                      float *d_q,
                      float *d_k,
                      const float *cos_cache,
                      const float *sin_cache,
                      int num_heads,
                      int num_kv_heads,
                      int num_tokens,
                      int head_dim,
                      int aligned_head_dim,
                      int pos_offset)
{
    /* Query-side gradients. */
    rope_backward(d_q_out, d_q, cos_cache, sin_cache,
                  num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
    /* Key-side gradients (KV head count may differ under GQA). */
    rope_backward(d_k_out, d_k, cos_cache, sin_cache,
                  num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
}
void rope_backward(const float *d_out, float *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:238

References rope_backward().

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ rope_backward_qk_bf16()

void rope_backward_qk_bf16 ( const uint16_t *  d_q_out,
const uint16_t *  d_k_out,
uint16_t *  d_q,
uint16_t *  d_k,
const float *  cos_cache,
const float *  sin_cache,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  pos_offset,
float *  scratch_dq_out,
float *  scratch_dq,
float *  scratch_dk_out,
float *  scratch_dk 
)

Definition at line 103 of file rope_kernels_bf16.c.

119 {
120  if (!d_q_out || !d_k_out || !d_q || !d_k) return;
121 
122  rope_backward_bf16(d_q_out, d_q, cos_cache, sin_cache,
123  num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset,
124  scratch_dq_out, scratch_dq);
125  rope_backward_bf16(d_k_out, d_k, cos_cache, sin_cache,
126  num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset,
127  scratch_dk_out, scratch_dk);
128 }
void rope_backward_bf16(const uint16_t *d_out, uint16_t *d_x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch_d_out, float *scratch_d_x)

References rope_backward_bf16().

◆ rope_forward()

void rope_forward ( float *  x,
const float *  cos_cache,
const float *  sin_cache,
int  num_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  pos_offset 
)

RoPE forward (head-major layout, in-place)

Test:

test_rope.py::TestRoPEForward::test_rope_forward

test_rope.py::TestRoPEForward::test_rope_vs_separate

test_parity.py::test_rope_parity

Applies rotary position embeddings in-place to a Q or K tensor. Layout: x is [num_heads, num_tokens, head_dim], head-major.

After changes: make test && make llamacpp-parity-full

Definition at line 180 of file rope_kernels.c.

188 {
189  size_t head_stride = (size_t)num_tokens * (size_t)aligned_head_dim;
190 
191  for (int h = 0; h < num_heads; ++h) {
192  rope_apply_head(x + h * head_stride,
193  cos_cache, sin_cache,
194  num_tokens, head_dim, aligned_head_dim, pos_offset);
195  }
196 }
static void rope_apply_head(float *x, const float *cos_cache, const float *sin_cache, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:79

References rope_apply_head().

Referenced by model_layer_0_decode(), model_layer_10_decode(), model_layer_11_decode(), model_layer_12_decode(), model_layer_13_decode(), model_layer_14_decode(), model_layer_15_decode(), model_layer_16_decode(), model_layer_17_decode(), model_layer_18_decode(), model_layer_19_decode(), model_layer_1_decode(), model_layer_20_decode(), model_layer_21_decode(), model_layer_22_decode(), model_layer_23_decode(), model_layer_2_decode(), model_layer_3_decode(), model_layer_4_decode(), model_layer_5_decode(), model_layer_6_decode(), model_layer_7_decode(), model_layer_8_decode(), model_layer_9_decode(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_9_decode(), rope_forward_bf16(), and rope_forward_qk().

◆ rope_forward_bf16()

void rope_forward_bf16 ( uint16_t *  x,
const float *  cos_cache,
const float *  sin_cache,
int  num_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  pos_offset,
float *  scratch 
)

Definition at line 28 of file rope_kernels_bf16.c.

37 {
38  if (!scratch) return;
39 
40  size_t total = (size_t)num_heads * (size_t)num_tokens * (size_t)aligned_head_dim;
41 
42  bf16_tensor_to_float(x, scratch, total);
43  rope_forward(scratch, cos_cache, sin_cache,
44  num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
45  float_tensor_to_bf16(scratch, x, total);
46 }
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:180

References bf16_tensor_to_float(), float_tensor_to_bf16(), and rope_forward().

Referenced by rope_forward_qk_bf16().

◆ rope_forward_qk()

void rope_forward_qk ( float *  q,
float *  k,
const float *  cos_cache,
const float *  sin_cache,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  pos_offset 
)

RoPE forward for both Q and K (common inference pattern)

Test:

test_rope.py::TestRoPEForward::test_rope_forward_qk

test_fused_attention_decode.py::TestFusedAttentionDecode::test_qk_rope

test_parity.py::test_rope_qk_parity

Combined RoPE forward for both Q and K in one call. Layouts: q is [num_heads, num_tokens, head_dim]; k is [num_kv_heads, num_tokens, head_dim].

After changes: make test && make llamacpp-parity-full

Definition at line 448 of file rope_kernels.c.

458 {
459  rope_forward(q, cos_cache, sin_cache, num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
460  rope_forward(k, cos_cache, sin_cache, num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset);
461 }
void rope_forward(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset)
Definition: rope_kernels.c:180

References rope_forward().

Referenced by ck_layer_forward_rmsnorm_swiglu(), ck_layer_forward_rmsnorm_swiglu_decode(), ck_layer_forward_rmsnorm_swiglu_decode_fused(), ck_layer_forward_rmsnorm_swiglu_decode_fused_attn_impl(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_layer_forward_rmsnorm_swiglu_decode_quant(), ck_layer_forward_rmsnorm_swiglu_q4_k(), ck_layer_forward_rmsnorm_swiglu_quant(), ck_layer_forward_rmsnorm_swiglu_ref(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), 
qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().

◆ rope_forward_qk_bf16()

void rope_forward_qk_bf16 ( uint16_t *  q,
uint16_t *  k,
const float *  cos_cache,
const float *  sin_cache,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  pos_offset,
float *  scratch_q,
float *  scratch_k 
)

Definition at line 79 of file rope_kernels_bf16.c.

91 {
92  if (!q || !k) return;
93 
94  rope_forward_bf16(q, cos_cache, sin_cache,
95  num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, scratch_q);
96  rope_forward_bf16(k, cos_cache, sin_cache,
97  num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, scratch_k);
98 }
void rope_forward_bf16(uint16_t *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, float *scratch)

References rope_forward_bf16().

◆ rope_forward_qk_strided()

void rope_forward_qk_strided ( float *  q,
float *  k,
const float *  cos_cache,
const float *  sin_cache,
int  num_heads,
int  num_kv_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  pos_offset,
int  q_stride_tokens,
int  k_stride_tokens 
)

RoPE forward for both Q and K with custom strides (KV cache layouts)

Test:

test_rope.py::TestRoPEForward::test_rope_forward_qk_strided

test_kv_cache_attention.py::TestKVCacheAttention::test_qk_rope_strided

Combined QK RoPE with configurable strides for KV cache layouts.

After changes: make test

Definition at line 472 of file rope_kernels.c.

484 {
485  rope_forward_strided(q, cos_cache, sin_cache, num_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, q_stride_tokens);
486  rope_forward_strided(k, cos_cache, sin_cache, num_kv_heads, num_tokens, head_dim, aligned_head_dim, pos_offset, k_stride_tokens);
487 }
void rope_forward_strided(float *x, const float *cos_cache, const float *sin_cache, int num_heads, int num_tokens, int head_dim, int aligned_head_dim, int pos_offset, int head_stride_tokens)
Definition: rope_kernels.c:207

References rope_forward_strided().

Referenced by mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), model_layer_0_prefill(), model_layer_10_prefill(), model_layer_11_prefill(), model_layer_12_prefill(), model_layer_13_prefill(), model_layer_14_prefill(), model_layer_15_prefill(), model_layer_16_prefill(), model_layer_17_prefill(), model_layer_18_prefill(), model_layer_19_prefill(), model_layer_1_prefill(), model_layer_20_prefill(), model_layer_21_prefill(), model_layer_22_prefill(), model_layer_23_prefill(), model_layer_2_prefill(), model_layer_3_prefill(), model_layer_4_prefill(), model_layer_5_prefill(), model_layer_6_prefill(), model_layer_7_prefill(), model_layer_8_prefill(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_prefill(), and qwen2_0_5b_decode_layer_9_prefill().

◆ rope_forward_strided()

void rope_forward_strided ( float *  x,
const float *  cos_cache,
const float *  sin_cache,
int  num_heads,
int  num_tokens,
int  head_dim,
int  aligned_head_dim,
int  pos_offset,
int  head_stride_tokens 
)

RoPE forward with custom head stride (for KV cache layouts)

Test:

test_rope.py::TestRoPEForward::test_rope_strided

test_kv_cache_attention.py::TestKVCacheAttention::test_rope_decode

Variant with configurable head_stride_tokens for non-contiguous head layouts.

After changes: make test

Definition at line 207 of file rope_kernels.c.

216 {
217  size_t head_stride = (size_t)head_stride_tokens * (size_t)aligned_head_dim;
218 
219  for (int h = 0; h < num_heads; ++h) {
220  rope_apply_head(x + h * head_stride,
221  cos_cache, sin_cache,
222  num_tokens, head_dim, aligned_head_dim, pos_offset);
223  }
224 }

References rope_apply_head().

Referenced by rope_forward_qk_strided().

◆ rope_precompute_cache()

void rope_precompute_cache ( float *  cos_cache,
float *  sin_cache,
int  max_seq_len,
int  head_dim,
float  base 
)

Precompute RoPE cos/sin cache

Test:

test_rope.py::TestRoPECache::test_cache_computation

test_rope.py::TestRoPECache::test_cache_values

Precomputes cos(m * theta_i) and sin(m * theta_i) for positions 0..max_seq_len-1. cos_cache, sin_cache: [max_seq_len, head_dim/2]

After changes: make test

Definition at line 52 of file rope_kernels.c.

57 {
58  int half_dim = head_dim / 2;
59 
60  long double base_ld = (long double)base;
61  long double head_dim_ld = (long double)head_dim;
62  long double log_base = logl(base_ld);
63  for (int pos = 0; pos < max_seq_len; ++pos) {
64  for (int i = 0; i < half_dim; ++i) {
65  long double exponent = ((long double)(2 * i)) / head_dim_ld;
66  long double freq = expl(-exponent * log_base);
67  float freq_f = (float)freq;
68  float angle_f = (float)pos * freq_f;
69  cos_cache[pos * half_dim + i] = cosf(angle_f);
70  sin_cache[pos * half_dim + i] = sinf(angle_f);
71  }
72  }
73 }

◆ scal_copy_f32()

void scal_copy_f32 ( float *  y,
const float *  x,
float  alpha,
int  n 
)

Scaled copy: y = alpha * x.

Parameters
yOutput vector [n]
xInput vector [n]
alphaScalar multiplier
nVector length

Definition at line 105 of file axpy_kernels.c.

109 {
110  if (!y || !x || n <= 0) {
111  return;
112  }
113 
114  int i = 0;
115 
116 #ifdef __AVX512F__
117  __m512 valpha = _mm512_set1_ps(alpha);
118  for (; i + 16 <= n; i += 16) {
119  __m512 vx = _mm512_loadu_ps(&x[i]);
120  __m512 vy = _mm512_mul_ps(vx, valpha);
121  _mm512_storeu_ps(&y[i], vy);
122  }
123 #endif
124 
125 #ifdef __AVX2__
126  __m256 valpha256 = _mm256_set1_ps(alpha);
127  for (; i + 8 <= n; i += 8) {
128  __m256 vx = _mm256_loadu_ps(&x[i]);
129  __m256 vy = _mm256_mul_ps(vx, valpha256);
130  _mm256_storeu_ps(&y[i], vy);
131  }
132 #endif
133 
134  for (; i < n; i++) {
135  y[i] = alpha * x[i];
136  }
137 }

Referenced by weighted_sum_f32().

◆ sigmoid_backward()

void sigmoid_backward ( const float *  input,
const float *  d_output,
float *  d_input,
size_t  n 
)

Definition at line 138 of file sigmoid_kernels.c.

142 {
143 #if defined(__AVX512F__)
144  sigmoid_backward_avx512(input, d_output, d_input, n);
145 #else
146  for (size_t i = 0; i < n; ++i) {
147  float x = input[i];
148  float s = sigmoid_scalar(x);
149  float s_prime = s * (1.0f - s);
150  d_input[i] = d_output[i] * s_prime;
151  }
152 #endif
153 }
float sigmoid_scalar(float x)

References sigmoid_scalar().

Referenced by sigmoid_backward_bf16().

◆ sigmoid_backward_bf16()

void sigmoid_backward_bf16 ( const uint16_t *  input,
const uint16_t *  d_output,
uint16_t *  d_input,
size_t  n,
float *  scratch_input,
float *  scratch_d_output,
float *  scratch_d_input 
)

Definition at line 45 of file sigmoid_kernels_bf16.c.

52 {
53  if (!input || !d_output || !d_input || n == 0) return;
54  if (!scratch_input || !scratch_d_output || !scratch_d_input) return;
55 
56  bf16_tensor_to_float(input, scratch_input, n);
57  bf16_tensor_to_float(d_output, scratch_d_output, n);
58  sigmoid_backward(scratch_input, scratch_d_output, scratch_d_input, n);
59  float_tensor_to_bf16(scratch_d_input, d_input, n);
60 }
void sigmoid_backward(const float *input, const float *d_output, float *d_input, size_t n)

References bf16_tensor_to_float(), float_tensor_to_bf16(), and sigmoid_backward().

◆ sigmoid_forward()

void sigmoid_forward ( const float *  input,
float *  output,
size_t  n 
)

Definition at line 122 of file sigmoid_kernels.c.

125 {
126 #if defined(__AVX512F__)
127  sigmoid_forward_avx512(input, output, n);
128 #else
129  for (size_t i = 0; i < n; ++i) {
130  output[i] = sigmoid_scalar(input[i]);
131  }
132 #endif
133 }

References sigmoid_scalar().

Referenced by sigmoid_forward_bf16().

◆ sigmoid_forward_bf16()

void sigmoid_forward_bf16 ( const uint16_t *  input,
uint16_t *  output,
size_t  n,
float *  scratch_input,
float *  scratch_output 
)

Definition at line 27 of file sigmoid_kernels_bf16.c.

32 {
33  if (!input || !output || n == 0) return;
34  if (!scratch_input || !scratch_output) return;
35 
36  bf16_tensor_to_float(input, scratch_input, n);
37  sigmoid_forward(scratch_input, scratch_output, n);
38  float_tensor_to_bf16(scratch_output, output, n);
39 }
void sigmoid_forward(const float *input, float *output, size_t n)

References bf16_tensor_to_float(), float_tensor_to_bf16(), and sigmoid_forward().

◆ sigmoid_scalar()

float sigmoid_scalar ( float  x)

◆ softmax_cross_entropy_loss()

void softmax_cross_entropy_loss ( const float *  logits,
const int32_t *  targets,
int  tokens,
int  vocab_size,
float *  d_logits,
float *  loss_out 
)

Definition at line 21 of file loss_kernels.c.

27 {
28  if (!logits || !targets || !d_logits || tokens <= 0 || vocab_size <= 0) {
29  if (loss_out) {
30  *loss_out = 0.0f;
31  }
32  return;
33  }
34 
35  double total_loss = 0.0;
36 
37  for (int t = 0; t < tokens; ++t) {
38  const float *row = logits + (size_t)t * (size_t)vocab_size;
39  float *drow = d_logits + (size_t)t * (size_t)vocab_size;
40  int target = targets[t];
41 
42  float max_logit = row[0];
43  for (int v = 1; v < vocab_size; ++v) {
44  if (row[v] > max_logit) {
45  max_logit = row[v];
46  }
47  }
48 
49  double sum_exp = 0.0;
50  for (int v = 0; v < vocab_size; ++v) {
51  float e = expf(row[v] - max_logit);
52  drow[v] = e;
53  sum_exp += e;
54  }
55 
56  float inv_sum = 1.0f / (float)sum_exp;
57  for (int v = 0; v < vocab_size; ++v) {
58  drow[v] *= inv_sum;
59  }
60 
61  if (target >= 0 && target < vocab_size) {
62  total_loss += -logf(drow[target] + 1e-10f);
63  drow[target] -= 1.0f;
64  }
65 
66  float scale = 1.0f / (float)tokens;
67  for (int v = 0; v < vocab_size; ++v) {
68  drow[v] *= scale;
69  }
70  }
71 
72  if (loss_out) {
73  *loss_out = (float)(total_loss / (double)tokens);
74  }
75 }

References vocab_size.

Referenced by softmax_cross_entropy_loss_bf16().

◆ softmax_cross_entropy_loss_bf16()

void softmax_cross_entropy_loss_bf16 ( const uint16_t *  logits,
const int32_t *  targets,
int  tokens,
int  vocab_size,
uint16_t *  d_logits,
float *  loss_out,
float *  scratch_logits,
float *  scratch_d_logits 
)

Definition at line 25 of file loss_kernels_bf16.c.

33 {
34  if (!logits || !targets || !d_logits || tokens <= 0 || vocab_size <= 0) {
35  if (loss_out) *loss_out = 0.0f;
36  return;
37  }
38  if (!scratch_logits || !scratch_d_logits) {
39  if (loss_out) *loss_out = 0.0f;
40  return;
41  }
42 
43  const size_t count = (size_t)tokens * (size_t)vocab_size;
44 
45  bf16_tensor_to_float(logits, scratch_logits, count);
46  softmax_cross_entropy_loss(scratch_logits, targets, tokens, vocab_size, scratch_d_logits, loss_out);
47  float_tensor_to_bf16(scratch_d_logits, d_logits, count);
48 }
void softmax_cross_entropy_loss(const float *logits, const int32_t *targets, int tokens, int vocab_size, float *d_logits, float *loss_out)
Definition: loss_kernels.c:21

References bf16_tensor_to_float(), float_tensor_to_bf16(), softmax_cross_entropy_loss(), and vocab_size.

◆ swiglu_backward()

void swiglu_backward ( const float *  input,
const float *  d_output,
float *  d_input,
int  tokens,
int  dim 
)

SwiGLU backward pass

Test:

test_swiglu.py::TestSwiGLUBackward::test_backward_tokens

test_swiglu.py::TestSwiGLUBackward::test_backward_single

test_parity.py::test_swiglu_backward_parity

Computes dGate and dUp given dY. dGate = dy * b * silu'(a), dUp = dy * silu(a)

After changes: make test && make llamacpp-parity-full

Definition at line 215 of file swiglu_kernels.c.

220 {
221  int T = tokens;
222  int D = dim;
223 
224  for (int t = 0; t < T; ++t) {
225  const float *row = input + (size_t)t * (2 * D);
226  const float *dy_row = d_output + (size_t)t * D;
227  float *dx_row = d_input + (size_t)t * (2 * D);
228  int d = 0;
229 
230 #if defined(__AVX512F__)
231  // AVX-512: Process 16 floats at a time
232  __m512 one = _mm512_set1_ps(1.0f);
233  for (; d + 16 <= D; d += 16) {
234  __m512 a = _mm512_loadu_ps(&row[d]); // gate
235  __m512 b = _mm512_loadu_ps(&row[D + d]); // value
236  __m512 dy = _mm512_loadu_ps(&dy_row[d]);
237 
238  __m512 s = sigmoid512_fast(a); // sigmoid(a)
239  __m512 silu = _mm512_mul_ps(a, s); // silu(a) = a * s
240  __m512 s_prime = _mm512_mul_ps(s, _mm512_sub_ps(one, s)); // s * (1 - s)
241  __m512 silu_prime = _mm512_fmadd_ps(a, s_prime, s); // s + a * s_prime
242 
243  // dA = dy * b * silu_prime
244  __m512 dA = _mm512_mul_ps(dy, _mm512_mul_ps(b, silu_prime));
245  // dB = dy * silu
246  __m512 dB = _mm512_mul_ps(dy, silu);
247 
248  _mm512_storeu_ps(&dx_row[d], dA);
249  _mm512_storeu_ps(&dx_row[D + d], dB);
250  }
251 #elif defined(__AVX2__)
252  // AVX2: Process 8 floats at a time
253  __m256 one = _mm256_set1_ps(1.0f);
254  for (; d + 8 <= D; d += 8) {
255  __m256 a = _mm256_loadu_ps(&row[d]); // gate
256  __m256 b = _mm256_loadu_ps(&row[D + d]); // value
257  __m256 dy = _mm256_loadu_ps(&dy_row[d]);
258 
259  __m256 s = sigmoid256_fast(a); // sigmoid(a)
260  __m256 silu = _mm256_mul_ps(a, s); // silu(a) = a * s
261  __m256 s_prime = _mm256_mul_ps(s, _mm256_sub_ps(one, s)); // s * (1 - s)
262  __m256 silu_prime = _mm256_fmadd_ps(a, s_prime, s); // s + a * s_prime
263 
264  // dA = dy * b * silu_prime
265  __m256 dA = _mm256_mul_ps(dy, _mm256_mul_ps(b, silu_prime));
266  // dB = dy * silu
267  __m256 dB = _mm256_mul_ps(dy, silu);
268 
269  _mm256_storeu_ps(&dx_row[d], dA);
270  _mm256_storeu_ps(&dx_row[D + d], dB);
271  }
272 #elif defined(__AVX__)
273  // AVX1: Vectorize arithmetic, use scalar sigmoid
274  __m256 one = _mm256_set1_ps(1.0f);
275  float a_arr[8] __attribute__((aligned(32)));
276  float s_arr[8] __attribute__((aligned(32)));
277 
278  for (; d + 8 <= D; d += 8) {
279  __m256 a = _mm256_loadu_ps(&row[d]); // gate
280  __m256 b = _mm256_loadu_ps(&row[D + d]); // value
281  __m256 dy = _mm256_loadu_ps(&dy_row[d]);
282 
283  // Compute sigmoid scalarly
284  _mm256_store_ps(a_arr, a);
285  for (int j = 0; j < 8; ++j) {
286  s_arr[j] = sigmoid_scalar(a_arr[j]);
287  }
288  __m256 s = _mm256_load_ps(s_arr);
289 
290  __m256 silu = _mm256_mul_ps(a, s); // silu(a) = a * s
291  __m256 s_prime = _mm256_mul_ps(s, _mm256_sub_ps(one, s)); // s * (1 - s)
292  // silu_prime = s + a * s_prime (no FMA in AVX1)
293  __m256 a_s_prime = _mm256_mul_ps(a, s_prime);
294  __m256 silu_prime = _mm256_add_ps(s, a_s_prime);
295 
296  // dA = dy * b * silu_prime
297  __m256 dA = _mm256_mul_ps(dy, _mm256_mul_ps(b, silu_prime));
298  // dB = dy * silu
299  __m256 dB = _mm256_mul_ps(dy, silu);
300 
301  _mm256_storeu_ps(&dx_row[d], dA);
302  _mm256_storeu_ps(&dx_row[D + d], dB);
303  }
304 #endif
305 
306  // Scalar fallback for remaining elements
307  for (; d < D; ++d) {
308  float a = row[d]; // gate
309  float b = row[D + d]; // value
310  float dy = dy_row[d];
311 
312  float s = sigmoid_scalar(a); // sigmoid(a)
313  float silu = a * s; // silu(a)
314  float s_prime = s * (1.0f - s); // sigmoid'(a)
315  float silu_prime = s + a * s_prime; // silu'(a)
316 
317  float dA = dy * b * silu_prime;
318  float dB = dy * silu;
319 
320  dx_row[d] = dA;
321  dx_row[D + d] = dB;
322  }
323  }
324 }
float sigmoid_scalar(float x)

References __attribute__(), sigmoid_scalar(), and silu().

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ swiglu_backward_bf16()

void swiglu_backward_bf16 ( const uint16_t *  input,
const uint16_t *  d_output,
uint16_t *  d_input,
int  tokens,
int  dim 
)

Definition at line 108 of file swiglu_kernels_bf16.c.

113 {
114  if (!input || !d_output || !d_input || tokens <= 0 || dim <= 0) {
115  return;
116  }
117 
118  const int T = tokens;
119  const int D = dim;
120 
121  for (int t = 0; t < T; ++t) {
122  const uint16_t *row = input + (size_t)t * (size_t)(2 * D);
123  const uint16_t *dy_row = d_output + (size_t)t * (size_t)D;
124  uint16_t *dx_row = d_input + (size_t)t * (size_t)(2 * D);
125  int d = 0;
126 
127 #if defined(__AVX512F__)
128  // AVX-512: Process 16 floats at a time
129  __m512 one = _mm512_set1_ps(1.0f);
130  for (; d + 16 <= D; d += 16) {
131  __m512 a = bf16_loadu_cvt_fp32(&row[d]); // gate
132  __m512 b = bf16_loadu_cvt_fp32(&row[D + d]); // value
133  __m512 dy = bf16_loadu_cvt_fp32(&dy_row[d]);
134 
135  __m512 s = sigmoid512_fast_bf16(a); // sigmoid(a)
136  __m512 silu = _mm512_mul_ps(a, s); // silu(a) = a * s
137  __m512 s_prime = _mm512_mul_ps(s, _mm512_sub_ps(one, s)); // s * (1 - s)
138  __m512 silu_prime = _mm512_fmadd_ps(a, s_prime, s); // s + a * s_prime
139 
140  // dA = dy * b * silu_prime
141  __m512 dA = _mm512_mul_ps(dy, _mm512_mul_ps(b, silu_prime));
142  // dB = dy * silu
143  __m512 dB = _mm512_mul_ps(dy, silu);
144 
145  fp32_cvt_storeu_bf16(&dx_row[d], dA);
146  fp32_cvt_storeu_bf16(&dx_row[D + d], dB);
147  }
148 #endif
149 
150  // Scalar fallback for remaining elements
151  for (; d < D; ++d) {
152  float a = bf16_to_float(row[d]);
153  float b = bf16_to_float(row[D + d]);
154  float dy = bf16_to_float(dy_row[d]);
155 
156  float s = sigmoid_scalar(a);
157  float silu = a * s;
158  float s_prime = s * (1.0f - s);
159  float silu_prime = s + a * s_prime;
160 
161  float dA = dy * b * silu_prime;
162  float dB = dy * silu;
163 
164  dx_row[d] = float_to_bf16(dA);
165  dx_row[D + d] = float_to_bf16(dB);
166  }
167  }
168 }

References bf16_to_float(), float_to_bf16(), sigmoid_scalar(), and silu().

◆ swiglu_backward_exact()

void swiglu_backward_exact ( const float *  input,
const float *  d_output,
float *  d_input,
int  tokens,
int  dim 
)

SwiGLU backward pass (exact version using stdlib sigmoid)

Test:

test_swiglu.py::TestSwiGLUBackward::test_exact_vs_fast

test_swiglu.py::TestSwiGLUBackward::test_exact_single

Uses standard library expf for numerical accuracy reference.

After changes: make test

Definition at line 373 of file swiglu_kernels.c.

378 {
379  int T = tokens;
380  int D = dim;
381 
382  for (int t = 0; t < T; ++t) {
383  const float *row = input + (size_t)t * (2 * D);
384  const float *dy_row = d_output + (size_t)t * D;
385  float *dx_row = d_input + (size_t)t * (2 * D);
386 
387  for (int d = 0; d < D; ++d) {
388  float a = row[d]; // gate
389  float b = row[D + d]; // value
390  float dy = dy_row[d];
391 
392  // Use standard library expf via sigmoid_scalar
393  float s = sigmoid_scalar(a); // sigmoid(a)
394  float silu = a * s; // silu(a)
395  float s_prime = s * (1.0f - s); // sigmoid'(a)
396  float silu_prime = s + a * s_prime; // silu'(a)
397 
398  float dA = dy * b * silu_prime;
399  float dB = dy * silu;
400 
401  dx_row[d] = dA;
402  dx_row[D + d] = dB;
403  }
404  }
405 }

References sigmoid_scalar(), and silu().

◆ swiglu_forward()

void swiglu_forward ( const float *  input,
float *  output,
int  tokens,
int  dim 
)

SwiGLU forward pass

Test:

test_swiglu.py::TestSwiGLUForward::test_forward_tokens

test_swiglu.py::TestSwiGLUForward::test_forward_single

test_mlp.py::TestMLPForward::test_swiglu_mlp

test_fused_swiglu_decode.py::TestFusedSwiGLUDecode::test_fused_swiglu_decode

test_parity.py::test_swiglu_parity

SwiGLU: y = silu(gate) * up where silu(x) = x * sigmoid(x)

After changes: make test && make llamacpp-parity-full

Definition at line 131 of file swiglu_kernels.c.

135 {
136  int T = tokens;
137  int D = dim;
138 
139  for (int t = 0; t < T; ++t) {
140  const float *row = input + (size_t)t * (2 * D);
141  float *out_row = output + (size_t)t * D;
142  int d = 0;
143 
144 #if defined(__AVX512F__)
145  // AVX-512: Process 16 floats at a time
146  for (; d + 16 <= D; d += 16) {
147  __m512 a = _mm512_loadu_ps(&row[d]); // gate
148  __m512 b = _mm512_loadu_ps(&row[D + d]); // value
149 
150  __m512 s = sigmoid512_fast(a); // sigmoid(a)
151  __m512 silu = _mm512_mul_ps(a, s); // silu(a) = a * sigmoid(a)
152  __m512 y = _mm512_mul_ps(silu, b); // y = silu(a) * b
153 
154  _mm512_storeu_ps(&out_row[d], y);
155  }
156 #elif defined(__AVX2__)
157  // AVX2: Process 8 floats at a time
158  for (; d + 8 <= D; d += 8) {
159  __m256 a = _mm256_loadu_ps(&row[d]); // gate
160  __m256 b = _mm256_loadu_ps(&row[D + d]); // value
161 
162  __m256 s = sigmoid256_fast(a); // sigmoid(a)
163  __m256 silu = _mm256_mul_ps(a, s); // silu(a) = a * sigmoid(a)
164  __m256 y = _mm256_mul_ps(silu, b); // y = silu(a) * b
165 
166  _mm256_storeu_ps(&out_row[d], y);
167  }
168 #elif defined(__AVX__)
169  // AVX1: Vectorize arithmetic, use scalar sigmoid
170  float a_arr[8] __attribute__((aligned(32)));
171  float s_arr[8] __attribute__((aligned(32)));
172 
173  for (; d + 8 <= D; d += 8) {
174  __m256 a = _mm256_loadu_ps(&row[d]); // gate
175  __m256 b = _mm256_loadu_ps(&row[D + d]); // value
176 
177  // Compute sigmoid scalarly
178  _mm256_store_ps(a_arr, a);
179  for (int j = 0; j < 8; ++j) {
180  s_arr[j] = sigmoid_scalar(a_arr[j]);
181  }
182  __m256 s = _mm256_load_ps(s_arr);
183 
184  __m256 silu = _mm256_mul_ps(a, s); // silu(a) = a * sigmoid(a)
185  __m256 y = _mm256_mul_ps(silu, b); // y = silu(a) * b
186 
187  _mm256_storeu_ps(&out_row[d], y);
188  }
189 #endif
190 
191  // Scalar fallback for remaining elements
192  for (; d < D; ++d) {
193  float a = row[d]; // gate
194  float b = row[D + d]; // value
195 
196  float s = sigmoid_scalar(a); // sigmoid(a)
197  float silu = a * s; // silu(a) = a * sigmoid(a)
198 
199  out_row[d] = silu * b;
200  }
201  }
202 }

References __attribute__(), sigmoid_scalar(), and silu().

Referenced by ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_q4_k(), ck_mlp_swiglu_forward_q4_k_q8_k(), ck_mlp_swiglu_forward_q4_k_q8_k_prefill(), ck_mlp_swiglu_forward_quant(), ck_mlp_swiglu_forward_ref(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), 
qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().

◆ swiglu_forward_bf16()

/**
 * SwiGLU forward pass over bf16 data.
 *
 * Input layout is [tokens, 2*dim]: the first dim entries of each row are the
 * gate activations, the next dim entries the values. Output is [tokens, dim]
 * with out[t][d] = silu(gate) * value, where silu(a) = a * sigmoid(a).
 *
 * Silently returns on NULL pointers or non-positive sizes.
 */
void swiglu_forward_bf16(const uint16_t *input,
                         uint16_t *output,
                         int tokens,
                         int dim)
{
    if (!input || !output || tokens <= 0 || dim <= 0) {
        return;
    }

    for (int t = 0; t < tokens; ++t) {
        const uint16_t *row = input + (size_t)t * (size_t)(2 * dim);
        uint16_t *out_row = output + (size_t)t * (size_t)dim;
        int d = 0;

#if defined(__AVX512F__)
        /* Vector path: 16 lanes per iteration, bf16 widened to fp32. */
        for (; d + 16 <= dim; d += 16) {
            __m512 gate = bf16_loadu_cvt_fp32(&row[d]);
            __m512 val  = bf16_loadu_cvt_fp32(&row[dim + d]);

            __m512 sig  = sigmoid512_fast_bf16(gate);
            __m512 act  = _mm512_mul_ps(gate, sig);     /* silu(gate) */
            __m512 prod = _mm512_mul_ps(act, val);      /* silu(gate) * value */

            fp32_cvt_storeu_bf16(&out_row[d], prod);
        }
#endif

        /* Scalar tail (and full path when AVX-512 is unavailable). */
        for (; d < dim; ++d) {
            float gate = bf16_to_float(row[d]);
            float val = bf16_to_float(row[dim + d]);
            float sig = sigmoid_scalar(gate);
            float act = gate * sig;
            out_row[d] = float_to_bf16(act * val);
        }
    }
}

References bf16_to_float(), float_to_bf16(), sigmoid_scalar(), and silu().

◆ swiglu_forward_exact()

/**
 * SwiGLU forward pass (exact version using stdlib sigmoid).
 *
 * Reference implementation relying on sigmoid_scalar (standard-library expf)
 * for numerical accuracy; compare against the fast/bf16 variants.
 *
 * Input layout is [tokens, 2*dim]: gate activations in the first dim entries
 * of each row, values in the next dim. Output is [tokens, dim] with
 * out[t][d] = silu(gate) * value.
 *
 * Test: test_swiglu.py::TestSwiGLUForward::test_exact_vs_fast
 *       test_swiglu.py::TestSwiGLUForward::test_exact_single
 * After changes: make test
 */
void swiglu_forward_exact(const float *input,
                          float *output,
                          int tokens,
                          int dim)
{
    /* Guard added for consistency with swiglu_forward_bf16. */
    if (!input || !output || tokens <= 0 || dim <= 0) {
        return;
    }

    for (int t = 0; t < tokens; ++t) {
        /* size_t casts match the bf16 variant and avoid int-product overflow
         * in the row-offset computation for very large tensors. */
        const float *row = input + (size_t)t * (size_t)(2 * dim);
        float *out_row = output + (size_t)t * (size_t)dim;

        for (int d = 0; d < dim; ++d) {
            float a = row[d];       /* gate  */
            float b = row[dim + d]; /* value */

            float s = sigmoid_scalar(a); /* sigmoid(a) = 1/(1+expf(-a)) */
            float silu = a * s;          /* silu(a) = a * sigmoid(a)    */

            out_row[d] = silu * b;
        }
    }
}

References sigmoid_scalar(), and silu().

◆ topk_batched_f32()

/**
 * Batched top-K selection for multiple tokens.
 *
 * @param scores      Input scores [num_tokens, n_experts]
 * @param num_tokens  Number of tokens
 * @param n_experts   Number of experts
 * @param k           Number of experts to select per token
 * @param indices     Output: selected expert indices [num_tokens, k]
 * @param weights     Output: routing weights [num_tokens, k]
 *                    (may be NULL: indices only, no softmax)
 */
void topk_batched_f32(const float *scores,
                      int num_tokens,
                      int n_experts,
                      int k,
                      int *indices,
                      float *weights)
{
    if (!scores || !indices || num_tokens <= 0 || n_experts <= 0 || k <= 0) {
        return;
    }

    for (int t = 0; t < num_tokens; t++) {
        /* Cast to size_t before multiplying so the per-token offsets cannot
         * overflow int for large batches (matches the other kernels here). */
        const float *token_scores = scores + (size_t)t * (size_t)n_experts;
        int *token_indices = indices + (size_t)t * (size_t)k;

        if (weights) {
            float *token_weights = weights + (size_t)t * (size_t)k;
            topk_softmax_f32(token_scores, n_experts, k,
                             token_indices, token_weights);
        } else {
            topk_f32(token_scores, n_experts, k, token_indices, NULL);
        }
    }
}
void topk_f32(const float *scores, int n, int k, int *indices, float *values)
Find top-K indices and values from a score vector.
Definition: topk_kernels.c:49
void topk_softmax_f32(const float *scores, int n, int k, int *indices, float *weights)
Find top-K indices with softmax-normalized weights.
Definition: topk_kernels.c:134

References topk_f32(), and topk_softmax_f32().

◆ topk_f32()

/* Position of the smallest element in v[0..k-1]. */
static int topk_min_pos(const float *v, int k)
{
    int m = 0;
    for (int i = 1; i < k; i++) {
        if (v[i] < v[m]) {
            m = i;
        }
    }
    return m;
}

/**
 * Find top-K indices and values from a score vector.
 *
 * @param scores  Input scores [n]
 * @param n       Number of scores (e.g., number of experts)
 * @param k       Number of top scores to select (clamped to n)
 * @param indices Output: indices of top-K scores [k], sorted descending by value
 * @param values  Output: top-K score values [k], sorted descending (may be NULL)
 */
void topk_f32(const float *scores, int n, int k, int *indices, float *values)
{
    if (!scores || !indices || n <= 0 || k <= 0) {
        return;
    }
    if (k > n) {
        k = n;
    }

    /* NOTE(review): VLA sized by caller-supplied k — fine for small expert
     * counts, but worth bounding if k can ever be large. */
    float held[k];

    /* Seed the running top-k with the first k scores. */
    for (int i = 0; i < k; i++) {
        indices[i] = i;
        held[i] = scores[i];
    }
    int worst = topk_min_pos(held, k);

    /* Each remaining score evicts the current minimum if it beats it. */
    for (int i = k; i < n; i++) {
        if (scores[i] > held[worst]) {
            indices[worst] = i;
            held[worst] = scores[i];
            worst = topk_min_pos(held, k);
        }
    }

    /* Order the winners descending (insertion sort; k is small). */
    for (int i = 1; i < k; i++) {
        float v = held[i];
        int ix = indices[i];
        int j = i;
        while (j > 0 && held[j - 1] < v) {
            held[j] = held[j - 1];
            indices[j] = indices[j - 1];
            j--;
        }
        held[j] = v;
        indices[j] = ix;
    }

    if (values) {
        for (int i = 0; i < k; i++) {
            values[i] = held[i];
        }
    }
}

Referenced by topk_batched_f32(), and topk_softmax_f32().

◆ topk_softmax_f32()

/**
 * Find top-K indices with softmax-normalized weights.
 *
 * @param scores  Input scores [n] (router logits)
 * @param n       Number of scores
 * @param k       Number of top scores to select (clamped to n)
 * @param indices Output: indices of top-K scores [k]
 * @param weights Output: softmax-normalized weights for selected [k], sum to 1.0
 */
void topk_softmax_f32(const float *scores, int n, int k,
                      int *indices, float *weights)
{
    if (!scores || !indices || !weights || n <= 0 || k <= 0) {
        return;
    }
    if (k > n) {
        k = n;
    }

    /* Top-K selection first; softmax runs only over the winners. */
    float picked[k];
    topk_f32(scores, n, k, indices, picked);

    /* Max-subtraction keeps expf in a safe range (numerical stability). */
    float peak = picked[0];
    for (int i = 1; i < k; i++) {
        if (picked[i] > peak) {
            peak = picked[i];
        }
    }

    /* Exponentiate and accumulate the partition sum. */
    float total = 0.0f;
    for (int i = 0; i < k; i++) {
        weights[i] = expf(picked[i] - peak);
        total += weights[i];
    }

    /* Single division, then multiply through (same rounding as original). */
    float scale = 1.0f / total;
    for (int i = 0; i < k; i++) {
        weights[i] *= scale;
    }
}

References topk_f32().

Referenced by topk_batched_f32().

◆ unfused_rmsnorm_qkv_prefill()

void unfused_rmsnorm_qkv_prefill ( const float *  x,
const float *  gamma,
const float *  Wq,
const float *  Wk,
const float *  Wv,
float *  x_norm,
float *  Q,
float *  K,
float *  V,
int  seq_len,
int  hidden,
int  q_dim,
int  kv_dim,
float  eps 
)

Unfused version for benchmarking comparison.

Unfused version for benchmarking comparison.

Definition at line 667 of file prefill_fused_gemm.c.

682 {
683  /* Step 1: Full RMSNorm → writes x_norm to memory */
684  rmsnorm_tile(x, gamma, x_norm, seq_len, hidden, hidden, eps);
685 
686  /* Step 2: Separate GEMMs with N-outer tiling for weight reuse */
687  /* Q projection */
688  for (int n_start = 0; n_start < q_dim; n_start += PREFILL_TILE_N) {
689  int tile_n = (n_start + PREFILL_TILE_N <= q_dim)
690  ? PREFILL_TILE_N : (q_dim - n_start);
691  const float *W_tile = Wq + (size_t)n_start * hidden;
692 
693  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
694  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
695  ? PREFILL_TILE_M : (seq_len - m_start);
696  const float *x_tile = x_norm + (size_t)m_start * hidden;
697  float *out_tile = Q + (size_t)m_start * q_dim + n_start;
698  gemm_tile_nt_strided(x_tile, W_tile, out_tile,
699  tile_m, tile_n, hidden, q_dim);
700  }
701  }
702 
703  /* K projection */
704  for (int n_start = 0; n_start < kv_dim; n_start += PREFILL_TILE_N) {
705  int tile_n = (n_start + PREFILL_TILE_N <= kv_dim)
706  ? PREFILL_TILE_N : (kv_dim - n_start);
707  const float *W_tile = Wk + (size_t)n_start * hidden;
708 
709  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
710  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
711  ? PREFILL_TILE_M : (seq_len - m_start);
712  const float *x_tile = x_norm + (size_t)m_start * hidden;
713  float *out_tile = K + (size_t)m_start * kv_dim + n_start;
714  gemm_tile_nt_strided(x_tile, W_tile, out_tile,
715  tile_m, tile_n, hidden, kv_dim);
716  }
717  }
718 
719  /* V projection */
720  for (int n_start = 0; n_start < kv_dim; n_start += PREFILL_TILE_N) {
721  int tile_n = (n_start + PREFILL_TILE_N <= kv_dim)
722  ? PREFILL_TILE_N : (kv_dim - n_start);
723  const float *W_tile = Wv + (size_t)n_start * hidden;
724 
725  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
726  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
727  ? PREFILL_TILE_M : (seq_len - m_start);
728  const float *x_tile = x_norm + (size_t)m_start * hidden;
729  float *out_tile = V + (size_t)m_start * kv_dim + n_start;
730  gemm_tile_nt_strided(x_tile, W_tile, out_tile,
731  tile_m, tile_n, hidden, kv_dim);
732  }
733  }
734 }
#define PREFILL_TILE_N

References gemm_tile_nt_strided(), PREFILL_TILE_M, PREFILL_TILE_N, and rmsnorm_tile().

◆ vec_dot_q6_k_q8_k()

void vec_dot_q6_k_q8_k ( int  n,
float *  s,
const void *  vx,
const void *  vy 
)

Q6_K x Q8_K dot product (single row)

Definition at line 954 of file gemm_kernels_q6k_q8k.c.

955 {
956  if (!s || !vx || !vy || n <= 0) {
957  return;
958  }
959 
960  const block_q6_K *x = (const block_q6_K *)vx;
961  const block_q8_K *y = (const block_q8_K *)vy;
962 
963  /* Dispatch based on available SIMD */
964 #if defined(__AVX512F__) && defined(__AVX512BW__)
965  *s = dot_q6_k_q8_k_avx512(x, y, n);
966 #elif defined(__AVX2__)
967  *s = dot_q6_k_q8_k_avx2(x, y, n);
968 #elif defined(__AVX__) && !defined(__AVX2__)
969  *s = dot_q6_k_q8_k_avx(x, y, n);
970 #elif defined(__SSSE3__)
971  *s = dot_q6_k_q8_k_sse(x, y, n);
972 #else
973  *s = dot_q6_k_q8_k_ref(x, y, n);
974 #endif
975 }

References dot_q6_k_q8_k_ref().

◆ weighted_sum_f32()

/**
 * Weighted sum of k vectors: y = sum_i(weights[i] * vectors[i]).
 *
 * @param y       Output vector [n]
 * @param vectors Array of k input vector pointers, each [n]
 * @param weights Array of k scalar weights
 * @param k       Number of vectors to combine
 * @param n       Vector length
 */
void weighted_sum_f32(float *y,
                      const float **vectors,
                      const float *weights,
                      int k,
                      int n)
{
    if (!y || !vectors || !weights || k <= 0 || n <= 0) {
        return;
    }

    /* Seed the accumulator: y = weights[0] * vectors[0]. */
    scal_copy_f32(y, vectors[0], weights[0], n);

    /* Fold in the remaining vectors: y += weights[v] * vectors[v]. */
    for (int v = 1; v < k; v++) {
        axpy_f32(y, vectors[v], weights[v], n);
    }
}
void scal_copy_f32(float *y, const float *x, float alpha, int n)
Scaled copy: y = alpha * x.
Definition: axpy_kernels.c:105

References axpy_f32(), and scal_copy_f32().