/*
 * rmsnorm_q8_k_fused.c — source listing extracted from the C-Kernel-Engine
 * Doxygen documentation.
 */
/**
 * @file rmsnorm_q8_k_fused.c
 * @brief Fused RMSNorm + Q8_K Quantization kernel
 *
 * CK-ENGINE KERNEL RULES:
 * =======================
 * 1. NO malloc/free - memory via bump allocator, pointers passed in
 * 2. NO OpenMP - parallelization at orchestrator/codegen layer
 * 3. API must define: inputs, outputs, workspace, and memory layouts
 * 4. Pure computation - deterministic, no side effects
 *
 * After changes: make test && make llamacpp-parity-full
 *
 * FUSION BENEFIT: Eliminates intermediate FP32 buffer between RMSNorm and
 * quantization, keeping normalized values in registers/L1.
 */
17 
#pragma GCC target("avx,sse4.1")
#include <immintrin.h>
#include <math.h>
#include <stdint.h>
#include <string.h>

#include "ckernel_quant.h"
25 
26 // AVX1 horizontal sum helper
27 static inline float hsum256_ps_fused(__m256 v) {
28  __m128 hi = _mm256_extractf128_ps(v, 1);
29  __m128 lo = _mm256_castps256_ps128(v);
30  __m128 sum128 = _mm_add_ps(lo, hi);
31  sum128 = _mm_hadd_ps(sum128, sum128);
32  sum128 = _mm_hadd_ps(sum128, sum128);
33  return _mm_cvtss_f32(sum128);
34 }
35 
36 // AVX1 horizontal max helper
37 static inline float hmax256_ps_fused(__m256 v) {
38  __m128 hi = _mm256_extractf128_ps(v, 1);
39  __m128 lo = _mm256_castps256_ps128(v);
40  __m128 max128 = _mm_max_ps(lo, hi);
41  max128 = _mm_max_ps(max128, _mm_shuffle_ps(max128, max128, _MM_SHUFFLE(1, 0, 3, 2)));
42  max128 = _mm_max_ps(max128, _mm_shuffle_ps(max128, max128, _MM_SHUFFLE(0, 1, 0, 1)));
43  return _mm_cvtss_f32(max128);
44 }
45 
46 /**
47  * Fused RMSNorm + Q8_K Quantization
48  *
49  * Benefits:
50  * 1. Single pass over input data (reduces DRAM pressure)
51  * 2. Normalization results stay in registers for quantization
52  * 3. Keeps hot data in L1/L2 cache
53  */
54 void rmsnorm_q8_k_fused(const float *input,
55  const float *gamma,
56  void *vy,
57  int tokens,
58  int d_model,
59  int aligned_embed_dim,
60  float eps)
61 {
62  const int T = tokens;
63  const int D = d_model;
64  block_q8_K *y = (block_q8_K *)vy;
65 
66  for (int t = 0; t < T; ++t) {
67  const float *x = input + (size_t)t * aligned_embed_dim;
68 
69  // 1. Compute sum of squares using AVX
70  __m256 sum_sq_vec = _mm256_setzero_ps();
71  for (int d = 0; d < D; d += 8) {
72  __m256 xv = _mm256_loadu_ps(&x[d]);
73  sum_sq_vec = _mm256_add_ps(sum_sq_vec, _mm256_mul_ps(xv, xv));
74  }
75  float sum_sq = hsum256_ps_fused(sum_sq_vec);
76  float rstd = 1.0f / sqrtf(sum_sq / (float)D + eps);
77  __m256 vrstd = _mm256_set1_ps(rstd);
78 
79  // 2. We need the max absolute value of the NORMALIZED data for quantization
80  // y_i = gamma_i * (x_i * rstd)
81  // We do this in blocks of QK_K (256) to match Q8_K layout
82  for (int b = 0; b < D / QK_K; ++b) {
83  const float *xb = x + b * QK_K;
84  const float *gb = gamma + b * QK_K;
85  block_q8_K *out_block = &y[t * (D / QK_K) + b];
86 
87  // Local normalization and max search
88  __m256 v_max_abs = _mm256_setzero_ps();
89  float norm_buf[QK_K];
90 
91  for (int d = 0; d < QK_K; d += 8) {
92  __m256 xv = _mm256_loadu_ps(&xb[d]);
93  __m256 gv = _mm256_loadu_ps(&gb[d]);
94  __m256 normalized = _mm256_mul_ps(_mm256_mul_ps(xv, vrstd), gv);
95 
96  _mm256_storeu_ps(&norm_buf[d], normalized);
97 
98  __m256 v_abs = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), normalized);
99  v_max_abs = _mm256_max_ps(v_max_abs, v_abs);
100  }
101 
102  float max_val = hmax256_ps_fused(v_max_abs);
103  if (max_val == 0.0f) {
104  out_block->d = 0.0f;
105  memset(out_block->qs, 0, QK_K);
106  memset(out_block->bsums, 0, sizeof(out_block->bsums));
107  continue;
108  }
109 
110  // 3. Quantize to Q8_K
111  float iscale = -127.0f / max_val;
112  __m256 v_iscale = _mm256_set1_ps(iscale);
113  out_block->d = 1.0f / iscale;
114 
115  for (int j = 0; j < QK_K; j += 16) {
116  // AVX1 doesn't have 256-bit integer conversion, so we use 128-bit SSE for packing
117  __m128 n0 = _mm_loadu_ps(&norm_buf[j + 0]);
118  __m128 n1 = _mm_loadu_ps(&norm_buf[j + 4]);
119  __m128 n2 = _mm_loadu_ps(&norm_buf[j + 8]);
120  __m128 n3 = _mm_loadu_ps(&norm_buf[j + 12]);
121 
122  __m128i q0 = _mm_cvtps_epi32(_mm_mul_ps(n0, _mm256_castps256_ps128(v_iscale)));
123  __m128i q1 = _mm_cvtps_epi32(_mm_mul_ps(n1, _mm256_castps256_ps128(v_iscale)));
124  __m128i q2 = _mm_cvtps_epi32(_mm_mul_ps(n2, _mm256_castps256_ps128(v_iscale)));
125  __m128i q3 = _mm_cvtps_epi32(_mm_mul_ps(n3, _mm256_castps256_ps128(v_iscale)));
126 
127  __m128i q01 = _mm_packs_epi32(q0, q1);
128  __m128i q23 = _mm_packs_epi32(q2, q3);
129  __m128i q0123 = _mm_packs_epi16(q01, q23);
130 
131  _mm_storeu_si128((__m128i *)(out_block->qs + j), q0123);
132 
133  // Compute bsum for 16 elements
134  __m128i p01 = _mm_add_epi16(q01, q23);
135  p01 = _mm_add_epi16(p01, _mm_shuffle_epi32(p01, _MM_SHUFFLE(1, 0, 3, 2)));
136  p01 = _mm_add_epi16(p01, _mm_shufflelo_epi16(p01, _MM_SHUFFLE(1, 0, 3, 2)));
137  int16_t bsum = (int16_t)_mm_extract_epi16(p01, 0) + (int16_t)_mm_extract_epi16(p01, 1);
138  out_block->bsums[j / 16] = bsum;
139  }
140  }
141  }
142 }
/*
 * Declarations referenced from "ckernel_quant.h" (per the Doxygen listing,
 * which describes it as "Quantization block structures for weight-only
 * quantization"):
 *   - QK_K               : quantization block width (qs is int8_t[256],
 *                          so QK_K == 256 — confirm against the header)
 *   - block_q8_K         : contains float d; int8_t qs[256];
 *                          int16_t bsums[256/16];
 */