← Back to C-Kernel-Engine Docs Doxygen Source Documentation
rmsnorm_q8_k_fused.c File Reference

Fused RMSNorm + Q8_K Quantization kernel. More...

#include <immintrin.h>
#include <math.h>
#include <stdint.h>
#include <string.h>
#include "ckernel_quant.h"

Go to the source code of this file.

Functions

static float hmax256_ps_fused (__m256 v)
 
static float hsum256_ps_fused (__m256 v)
 
void rmsnorm_q8_k_fused (const float *input, const float *gamma, void *vy, int tokens, int d_model, int aligned_embed_dim, float eps)
 

Detailed Description

Fused RMSNorm + Q8_K Quantization kernel.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

FUSION BENEFIT: Eliminates intermediate FP32 buffer between RMSNorm and quantization, keeping normalized values in registers/L1.

Definition in file rmsnorm_q8_k_fused.c.

Function Documentation

◆ hmax256_ps_fused()

/**
 * Horizontal maximum of all 8 float lanes of an AVX vector.
 *
 * @param v  8-lane float vector.
 * @return   max(v[0..7]) as a scalar float.
 *
 * Operand order of every _mm_max_ps is kept deliberate: max_ps is not
 * symmetric in the presence of NaN (it returns the second operand), so
 * the reduction order is part of the contract.
 */
static inline float hmax256_ps_fused(__m256 v) {
    /* Fold the upper 128-bit half onto the lower half. */
    __m128 upper  = _mm256_extractf128_ps(v, 1);
    __m128 folded = _mm_max_ps(_mm256_castps256_ps128(v), upper);
    /* Swap the two 64-bit halves, then the adjacent lanes: 4 -> 2 -> 1. */
    folded = _mm_max_ps(folded, _mm_shuffle_ps(folded, folded, _MM_SHUFFLE(1, 0, 3, 2)));
    folded = _mm_max_ps(folded, _mm_shuffle_ps(folded, folded, _MM_SHUFFLE(0, 1, 0, 1)));
    return _mm_cvtss_f32(folded);
}

Referenced by rmsnorm_q8_k_fused().

◆ hsum256_ps_fused()

/**
 * Horizontal sum of all 8 float lanes of an AVX vector.
 *
 * @param v  8-lane float vector.
 * @return   v[0] + v[1] + ... + v[7] as a scalar float.
 *
 * The add sequence (cross-half add, then two hadds) is kept exactly as
 * in the rest of the kernel family so the floating-point rounding order
 * is deterministic and reproducible.
 */
static inline float hsum256_ps_fused(__m256 v) {
    /* Fold the upper 128-bit half onto the lower half. */
    __m128 upper  = _mm256_extractf128_ps(v, 1);
    __m128 folded = _mm_add_ps(_mm256_castps256_ps128(v), upper);
    /* Two horizontal adds collapse 4 lanes -> 2 -> 1. */
    folded = _mm_hadd_ps(folded, folded);
    folded = _mm_hadd_ps(folded, folded);
    return _mm_cvtss_f32(folded);
}

Referenced by fused_rmsnorm_linear_q4k(), gemm_bias_gelu_fused(), gemm_bias_relu_fused(), gemm_bias_silu_fused(), gemm_swiglu_fused(), and rmsnorm_q8_k_fused().

◆ rmsnorm_q8_k_fused()

/**
 * Fused RMSNorm + Q8_K quantization.
 *
 * For each of `tokens` rows, computes y_i = gamma_i * (x_i * rstd) with
 * rstd = 1/sqrt(mean(x^2) + eps), then quantizes the normalized row to
 * Q8_K blocks (QK_K elements each) without materializing an intermediate
 * FP32 row in memory — normalized values stay in a QK_K-sized stack
 * buffer / registers between the two stages.
 *
 * @param input             Input activations, `tokens` rows, each row
 *                          strided by `aligned_embed_dim` floats (only the
 *                          first `d_model` entries of a row are read).
 * @param gamma             RMSNorm scale weights, `d_model` floats.
 * @param vy                Output: array of block_q8_K, laid out as
 *                          tokens * (d_model / QK_K) consecutive blocks.
 * @param tokens            Number of rows to process.
 * @param d_model           Logical row width.
 *                          Assumed d_model % 8 == 0 (the sum-of-squares
 *                          loop advances 8 at a time with no tail) and
 *                          effectively a multiple of QK_K: any remainder
 *                          d_model % QK_K is silently dropped by the
 *                          integer division D / QK_K. TODO(review): confirm
 *                          callers guarantee both.
 * @param aligned_embed_dim Physical (padded) row stride of `input`.
 * @param eps               Numerical-stability epsilon added to mean(x^2).
 *
 * Pure computation: no allocation, no globals, deterministic.
 */
void rmsnorm_q8_k_fused(const float *input, const float *gamma, void *vy,
                        int tokens, int d_model, int aligned_embed_dim,
                        float eps)
{
    const int T = tokens;
    const int D = d_model;
    block_q8_K *y = (block_q8_K *)vy;

    for (int t = 0; t < T; ++t) {
        /* Rows are strided by the padded width, not the logical one. */
        const float *x = input + (size_t)t * aligned_embed_dim;

        // 1. Compute sum of squares using AVX
        __m256 sum_sq_vec = _mm256_setzero_ps();
        for (int d = 0; d < D; d += 8) {
            __m256 xv = _mm256_loadu_ps(&x[d]);
            sum_sq_vec = _mm256_add_ps(sum_sq_vec, _mm256_mul_ps(xv, xv));
        }
        float sum_sq = hsum256_ps_fused(sum_sq_vec);
        /* rstd = 1/sqrt(mean(x^2) + eps); broadcast for the SIMD loops below. */
        float rstd = 1.0f / sqrtf(sum_sq / (float)D + eps);
        __m256 vrstd = _mm256_set1_ps(rstd);

        // 2. We need the max absolute value of the NORMALIZED data for quantization
        // y_i = gamma_i * (x_i * rstd)
        // We do this in blocks of QK_K (256) to match Q8_K layout
        for (int b = 0; b < D / QK_K; ++b) {
            const float *xb = x + b * QK_K;
            const float *gb = gamma + b * QK_K;
            block_q8_K *out_block = &y[t * (D / QK_K) + b];

            // Local normalization and max search
            __m256 v_max_abs = _mm256_setzero_ps();
            float norm_buf[QK_K];

            for (int d = 0; d < QK_K; d += 8) {
                __m256 xv = _mm256_loadu_ps(&xb[d]);
                __m256 gv = _mm256_loadu_ps(&gb[d]);
                __m256 normalized = _mm256_mul_ps(_mm256_mul_ps(xv, vrstd), gv);

                /* Stash normalized values for the quantization pass. */
                _mm256_storeu_ps(&norm_buf[d], normalized);

                /* andnot with -0.0f clears the sign bit: |normalized|. */
                __m256 v_abs = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), normalized);
                v_max_abs = _mm256_max_ps(v_max_abs, v_abs);
            }

            float max_val = hmax256_ps_fused(v_max_abs);
            if (max_val == 0.0f) {
                /* All-zero block: emit a fully zeroed Q8_K block. */
                out_block->d = 0.0f;
                memset(out_block->qs, 0, QK_K);
                memset(out_block->bsums, 0, sizeof(out_block->bsums));
                continue;
            }

            // 3. Quantize to Q8_K
            /* Negative scale maps +max -> -127; d = 1/iscale = -max/127
               inverts it on dequantization. NOTE(review): presumably this
               matches llama.cpp's quantize_row_q8_K convention — parity is
               checked by `make llamacpp-parity-full`. */
            float iscale = -127.0f / max_val;
            __m256 v_iscale = _mm256_set1_ps(iscale);
            out_block->d = 1.0f / iscale;

            for (int j = 0; j < QK_K; j += 16) {
                // AVX1 doesn't have 256-bit integer conversion, so we use 128-bit SSE for packing
                __m128 n0 = _mm_loadu_ps(&norm_buf[j + 0]);
                __m128 n1 = _mm_loadu_ps(&norm_buf[j + 4]);
                __m128 n2 = _mm_loadu_ps(&norm_buf[j + 8]);
                __m128 n3 = _mm_loadu_ps(&norm_buf[j + 12]);

                /* cvtps_epi32 = round-to-nearest-even; packs_* saturate to
                   int16 then int8, bounding quants to [-128, 127]. */
                __m128i q0 = _mm_cvtps_epi32(_mm_mul_ps(n0, _mm256_castps256_ps128(v_iscale)));
                __m128i q1 = _mm_cvtps_epi32(_mm_mul_ps(n1, _mm256_castps256_ps128(v_iscale)));
                __m128i q2 = _mm_cvtps_epi32(_mm_mul_ps(n2, _mm256_castps256_ps128(v_iscale)));
                __m128i q3 = _mm_cvtps_epi32(_mm_mul_ps(n3, _mm256_castps256_ps128(v_iscale)));

                __m128i q01 = _mm_packs_epi32(q0, q1);   /* int16 lanes j..j+7   */
                __m128i q23 = _mm_packs_epi32(q2, q3);   /* int16 lanes j+8..j+15 */
                __m128i q0123 = _mm_packs_epi16(q01, q23); /* 16 int8, in order */

                _mm_storeu_si128((__m128i *)(out_block->qs + j), q0123);

                // Compute bsum for 16 elements
                /* Fold 8 int16 lanes -> 4 -> 2, finish with a scalar add.
                   No overflow: |sum of 16 quants| <= 16*127 = 2032. */
                __m128i p01 = _mm_add_epi16(q01, q23);
                p01 = _mm_add_epi16(p01, _mm_shuffle_epi32(p01, _MM_SHUFFLE(1, 0, 3, 2)));
                p01 = _mm_add_epi16(p01, _mm_shufflelo_epi16(p01, _MM_SHUFFLE(1, 0, 3, 2)));
                int16_t bsum = (int16_t)_mm_extract_epi16(p01, 0) + (int16_t)_mm_extract_epi16(p01, 1);
                out_block->bsums[j / 16] = bsum;
            }
        }
    }
}
#define QK_K
static float hmax256_ps_fused(__m256 v)
static float hsum256_ps_fused(__m256 v)
int8_t qs[256]
int16_t bsums[256/16]

References block_q8_K::bsums, block_q8_K::d, hmax256_ps_fused(), hsum256_ps_fused(), QK_K, and block_q8_K::qs.