#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#include <immintrin.h>
#endif
void gemv_q4_k_q8_k(float *y, const void *W, const void *x_q8, int M, int K);
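/* Assumed calling convention (a sketch, not a confirmed ABI): W points to
 * M rows of K/QK_K quantized Q4_K weight blocks, and x_q8 to K/QK_K Q8_K
 * activation blocks laid out like the q8_buffer filled below; the test
 * harness's M * nb weight-block loop is consistent with this reading. */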
static int ck_nearest_int_fused(float fval) {
    float val = fval + 12582912.f;
    int i;
    memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;
}
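/* Why this works: 12582912.f is 1.5 * 2^23. For |fval| < 2^22, adding it
 * lands val in a binade where one ulp is 1.0, so the FPU's round-to-nearest
 * performs the integer rounding; the low 23 mantissa bits then hold
 * fval + 0x00400000, and the mask-and-subtract recovers the integer. Under
 * the default rounding mode this matches (int)lrintf(fval), which would be
 * the plain-C alternative. */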
#if defined(__AVX__) && !defined(__AVX512F__)
static float hsum256_ps_fused(__m256 v) {
    __m128 hi = _mm256_extractf128_ps(v, 1);
    __m128 lo = _mm256_castps256_ps128(v);
    __m128 sum128 = _mm_add_ps(lo, hi);
    sum128 = _mm_hadd_ps(sum128, sum128);
    sum128 = _mm_hadd_ps(sum128, sum128);
    return _mm_cvtss_f32(sum128);
}
#endif
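/* The helper sums the eight lanes of v: adding the two 128-bit halves
 * leaves four partial sums, and the two hadd steps fold 4 -> 2 -> 1.
 * AVX-512 builds skip it and use _mm512_reduce_add_ps directly. */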
void fused_rmsnorm_linear_q4k(float *y, const float *x, const float *gamma,
                              const void *W_q4k, int M, int K, float eps) {
    if (!y || !x || !gamma || !W_q4k || M <= 0 || K <= 0) {
        return;
    }
    assert(K % QK_K == 0);
    const int nb = K / QK_K;
    assert(nb <= 32 && "K too large for stack buffer");
#if defined(__AVX512F__)
    int d = 0;
    __m512 sum_sq_vec = _mm512_setzero_ps();
    for (; d + 16 <= K; d += 16) {
        __m512 xv = _mm512_loadu_ps(&x[d]);
        sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
    }
    float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
    for (; d < K; ++d) {
        sum_sq += x[d] * x[d];
    }
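    /* The scalar tail picks up the K % 16 remainder. Note the SIMD paths
     * accumulate sum_sq in float, while the scalar fallback below uses a
     * double accumulator. */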
#elif defined(__AVX__)
    int d = 0;
    __m256 sum_sq_vec = _mm256_setzero_ps();
    for (; d + 8 <= K; d += 8) {
        __m256 xv = _mm256_loadu_ps(&x[d]);
        __m256 xv_sq = _mm256_mul_ps(xv, xv);
        sum_sq_vec = _mm256_add_ps(sum_sq_vec, xv_sq);
    }
    float sum_sq = hsum256_ps_fused(sum_sq_vec);
    for (; d < K; ++d) {
        sum_sq += x[d] * x[d];
    }
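    /* Plain AVX has no FMA, hence the separate mul and add; an AVX2+FMA
     * build could use _mm256_fmadd_ps here as the AVX-512 path does. */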
#else
    double sum_sq = 0.0;
    for (int d = 0; d < K; ++d) {
        double v = (double)x[d];
        sum_sq += v * v;
    }
#endif

    float mean_sq = (float)sum_sq / (float)K;
    float rstd = 1.0f / sqrtf(mean_sq + eps);
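    /* RMSNorm: out[d] = x[d] * rstd * gamma[d], with
     * rstd = 1 / sqrt(mean(x^2) + eps) shared by every block below. */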
    for (int i = 0; i < nb; ++i) {
        const float *x_block = x + i * QK_K;
        const float *g_block = gamma + i * QK_K;

        float max_val = 0.0f;
#if defined(__AVX512F__)
        float amax = 0.0f;
        __m512 rstd_vec = _mm512_set1_ps(rstd);
        __m512 max_vec = _mm512_setzero_ps();
        __m512 sign_mask = _mm512_set1_ps(-0.0f);

        for (int j = 0; j < QK_K; j += 16) {
            __m512 xv = _mm512_loadu_ps(&x_block[j]);
            __m512 gv = _mm512_loadu_ps(&g_block[j]);
            __m512 norm = _mm512_mul_ps(_mm512_mul_ps(xv, rstd_vec), gv);
            __m512 abs_norm = _mm512_andnot_ps(sign_mask, norm);
            max_vec = _mm512_max_ps(max_vec, abs_norm);

            /* When any lane beats the running amax, rescan the group in
             * scalar code to capture the *signed* extremum in max_val; the
             * 1e-6f slack tolerates rounding differences between the vector
             * reduction and the scalar recomputation. */
            __mmask16 gt_mask = _mm512_cmp_ps_mask(abs_norm, _mm512_set1_ps(amax), _CMP_GT_OQ);
            float temp_amax = _mm512_reduce_max_ps(abs_norm);
            if (gt_mask && temp_amax > amax) {
                amax = temp_amax;
                for (int k = 0; k < 16; ++k) {
                    float v = x_block[j + k] * rstd * g_block[j + k];
                    if (fabsf(v) >= amax - 1e-6f) {
                        max_val = v;
                    }
                }
            }
        }
        amax = _mm512_reduce_max_ps(max_vec);
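        /* max_vec tracked every group, so this final reduction yields the
         * exact block amax; the in-loop rescan exists only to recover the
         * sign of that extremum for max_val. */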
#elif defined(__AVX__)
        float amax = 0.0f;
        __m256 rstd_vec = _mm256_set1_ps(rstd);

        for (int j = 0; j < QK_K; j += 8) {
            __m256 xv = _mm256_loadu_ps(&x_block[j]);
            __m256 gv = _mm256_loadu_ps(&g_block[j]);
            __m256 norm = _mm256_mul_ps(_mm256_mul_ps(xv, rstd_vec), gv);

            float norm_arr[8];
            _mm256_storeu_ps(norm_arr, norm);
            for (int k = 0; k < 8; ++k) {
                float av = fabsf(norm_arr[k]);
                if (av > amax) {
                    amax = av;
                    max_val = norm_arr[k];
                }
            }
        }
#else
        float amax = 0.0f;
        for (int j = 0; j < QK_K; ++j) {
            float norm = x_block[j] * rstd * g_block[j];
            float av = fabsf(norm);
            if (av > amax) { amax = av; max_val = norm; }
        }
#endif
        /* An all-zero block gets a zero scale and zeroed quants, and skips
         * the 1/max_val scale computation below. */
        if (max_val == 0.0f) {
            q8_buffer[i].d = 0.0f;
            memset(q8_buffer[i].qs, 0, sizeof(q8_buffer[i].qs));
            memset(q8_buffer[i].bsums, 0, sizeof(q8_buffer[i].bsums));
            continue;
        }
        const float iscale = -127.0f / max_val;
        q8_buffer[i].d = 1.0f / iscale;

        for (int j = 0; j < QK_K; ++j) {
            float norm = x_block[j] * rstd * g_block[j];
            int v = ck_nearest_int_fused(iscale * norm);
            v = (v > 127) ? 127 : ((v < -128) ? -128 : v);
            q8_buffer[i].qs[j] = (int8_t)v;
        }
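        /* iscale is built from the signed extremum, so that element always
         * maps to exactly -127 regardless of its sign; dequantization is
         * simply q * d with d = 1/iscale. */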
        for (int j = 0; j < QK_K / 16; ++j) {
            int sum = 0;
            const int8_t *qs = &q8_buffer[i].qs[j * 16];
            for (int k = 0; k < 16; ++k) {
                sum += qs[k];
            }
            q8_buffer[i].bsums[j] = (int16_t)sum;
        }
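        /* bsums caches the sum of each 16-element group of quants; Q4_K dot
         * kernels typically consume these to fold in the per-subblock
         * minimums without revisiting every quantized value. */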
void unfused_rmsnorm_linear_q4k_ref(float *y, const float *x, const float *gamma,
                                    const void *W_q4k, int M, int K, float eps) {
    if (!y || !x || !gamma || !W_q4k || M <= 0 || K <= 0) {
        return;
    }
    assert(K % QK_K == 0);
    if (K > 4096)
        return;
    float norm_out[4096];
    double sum_sq = 0.0;
    for (int d = 0; d < K; ++d) {
        sum_sq += (double)x[d] * (double)x[d];
    }
    float rstd = 1.0f / sqrtf((float)(sum_sq / K) + eps);

    for (int d = 0; d < K; ++d) {
        norm_out[d] = x[d] * rstd * gamma[d];
    }
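    /* From here the reference path presumably mirrors the fused kernel in
     * two separate passes: quantize norm_out with the project's
     * quantize_row_q8_k, then call gemv_q4_k_q8_k (declared above) on the
     * same weights. */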
#ifdef FUSED_KERNEL_TEST
int main(int argc, char **argv) {
    const int nb = K / QK_K;

    printf("Fused RMSNorm+Linear Test\n");
    printf("K=%d, M=%d, blocks=%d\n", K, M, nb);
    float *x = (float *)aligned_alloc(64, K * sizeof(float));
    float *gamma = (float *)aligned_alloc(64, K * sizeof(float));
    float *y_fused = (float *)aligned_alloc(64, M * sizeof(float));
    float *y_unfused = (float *)aligned_alloc(64, M * sizeof(float));
    for (int i = 0; i < K; ++i) {
        x[i] = (float)rand() / RAND_MAX * 2.0f - 1.0f;      /* x in [-1, 1] */
        gamma[i] = (float)rand() / RAND_MAX * 0.5f + 0.75f; /* gamma in [0.75, 1.25] */
    }
    for (int i = 0; i < M * nb; ++i) {
    printf("Running fused version...\n");

    printf("Running unfused version...\n");
    float max_diff = 0.0f;
    for (int i = 0; i < M; ++i) {
        float diff = fabsf(y_fused[i] - y_unfused[i]);
        if (diff > max_diff) max_diff = diff;
    }
    printf("Max difference: %e\n", max_diff);
    printf("Test %s\n", max_diff < 1e-3f ? "PASSED" : "FAILED");

    free(x);
    free(gamma);
    free(y_fused);
    free(y_unfused);
    return max_diff < 1e-3f ? 0 : 1;
}
#endif /* FUSED_KERNEL_TEST */
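/* Example build for the AVX2 path (the file name is hypothetical; pick ISA
 * flags matching whichever #if branch you want to exercise):
 *
 *     cc -O2 -mavx2 -mfma -DFUSED_KERNEL_TEST fused_rmsnorm_q4k.c -lm
 */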