18 #pragma GCC target("avx,sse4.1")
19 #include <immintrin.h>
// Horizontal sum of the eight float lanes of an AVX vector.
// Folds the 256-bit vector to 128 bits, then collapses with two hadds;
// the total ends up in lane 0.
static float hsum256_ps_fused(__m256 v) {
    __m128 low_half  = _mm256_castps256_ps128(v);
    __m128 high_half = _mm256_extractf128_ps(v, 1);
    __m128 acc = _mm_add_ps(low_half, high_half); // four pairwise sums
    acc = _mm_hadd_ps(acc, acc);                  // two partial sums
    acc = _mm_hadd_ps(acc, acc);                  // grand total in lane 0
    return _mm_cvtss_f32(acc);
}
// Horizontal max of the eight float lanes of an AVX vector.
// The reduction halves the live lane count at each step; lane 0 of the
// final vector holds the overall maximum.
static float hmax256_ps_fused(__m256 v) {
    __m128 lower = _mm256_castps256_ps128(v);
    __m128 upper = _mm256_extractf128_ps(v, 1);
    __m128 m = _mm_max_ps(lower, upper); // 8 lanes -> 4
    // Swap the two 64-bit halves and take the max: 4 lanes -> 2.
    m = _mm_max_ps(m, _mm_shuffle_ps(m, m, _MM_SHUFFLE(1, 0, 3, 2)));
    // Swap adjacent lanes and take the max: overall max lands in lane 0.
    m = _mm_max_ps(m, _mm_shuffle_ps(m, m, _MM_SHUFFLE(0, 1, 0, 1)));
    return _mm_cvtss_f32(m);
}
// rmsnorm_q8_k_fused (body fragment): fused RMSNorm + Q8_K-style block
// quantization over `tokens` rows of `d_model` floats each, rows padded to
// `aligned_embed_dim` floats.
// NOTE(review): this chunk is a garbled extraction — original line numbers
// are embedded in the text, statements are wrapped mid-token, and several
// source lines are missing from this view (marked below). Code is left
// byte-identical; comments only.
59 int aligned_embed_dim,
// D: logical embedding width actually normalized/quantized (only the first
// D of each aligned_embed_dim-float row are used).
63 const int D = d_model;
// Per-token loop. NOTE(review): T is presumably `tokens` — its definition
// is not visible in this chunk; confirm upstream.
66 for (
int t = 0; t < T; ++t) {
67 const float *x = input + (size_t)t * aligned_embed_dim;
// Accumulate the sum of squares of x[0..D) across 8-wide lanes.
// assumes D is a multiple of 8 — TODO confirm with callers.
70 __m256 sum_sq_vec = _mm256_setzero_ps();
71 for (
int d = 0; d < D; d += 8) {
72 __m256 xv = _mm256_loadu_ps(&x[d]);
73 sum_sq_vec = _mm256_add_ps(sum_sq_vec, _mm256_mul_ps(xv, xv));
// NOTE(review): original lines 74-75 (loop close plus the horizontal
// reduction of sum_sq_vec into scalar sum_sq, presumably via
// hsum256_ps_fused) are missing from this view.
// Reciprocal RMS: 1 / sqrt(mean(x^2) + eps).
76 float rstd = 1.0f / sqrtf(sum_sq / (
float)D + eps);
77 __m256 vrstd = _mm256_set1_ps(rstd);
// Quantize the normalized row in blocks of QK_K values.
// assumes D is a multiple of QK_K — TODO confirm.
82 for (
int b = 0; b < D /
QK_K; ++b) {
83 const float *xb = x + b *
QK_K;
84 const float *gb = gamma + b *
QK_K;
// Running per-lane max of |normalized| over the block, used to pick the
// quantization scale below.
88 __m256 v_max_abs = _mm256_setzero_ps();
91 for (
int d = 0; d <
QK_K; d += 8) {
92 __m256 xv = _mm256_loadu_ps(&xb[d]);
93 __m256 gv = _mm256_loadu_ps(&gb[d]);
// normalized = x * rstd * gamma (RMSNorm with learned gain).
94 __m256 normalized = _mm256_mul_ps(_mm256_mul_ps(xv, vrstd), gv);
// Stash normalized values for the quantization pass.
// NOTE(review): norm_buf's declaration is not visible in this chunk —
// presumably a QK_K-sized float scratch buffer.
96 _mm256_storeu_ps(&norm_buf[d], normalized);
// andnot with -0.0f clears the sign bit: v_abs = |normalized|.
98 __m256 v_abs = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), normalized);
99 v_max_abs = _mm256_max_ps(v_max_abs, v_abs);
// NOTE(review): original lines 100-102 (loop close plus the horizontal
// max reduction of v_max_abs into max_val, presumably via
// hmax256_ps_fused) are missing from this view.
// All-zero block: emit zero quants and zero bucket sums.
103 if (max_val == 0.0f) {
105 memset(out_block->
qs, 0,
QK_K);
106 memset(out_block->
bsums, 0,
sizeof(out_block->
bsums));
// NOTE(review): original lines 107-110 (presumably out_block->d = 0,
// a continue/advance, and the closing brace) are missing from this view.
// Map max |value| onto -127 so the int8 range is fully used; d stores the
// inverse scale for dequantization.
111 float iscale = -127.0f / max_val;
112 __m256 v_iscale = _mm256_set1_ps(iscale);
113 out_block->
d = 1.0f / iscale;
// Quantize 16 floats per iteration: scale, round via cvtps_epi32 (uses the
// current rounding mode, normally round-to-nearest-even), then saturating
// pack 32-bit -> 16-bit -> 8-bit.
115 for (
int j = 0; j <
QK_K; j += 16) {
117 __m128 n0 = _mm_loadu_ps(&norm_buf[j + 0]);
118 __m128 n1 = _mm_loadu_ps(&norm_buf[j + 4]);
119 __m128 n2 = _mm_loadu_ps(&norm_buf[j + 8]);
120 __m128 n3 = _mm_loadu_ps(&norm_buf[j + 12]);
122 __m128i q0 = _mm_cvtps_epi32(_mm_mul_ps(n0, _mm256_castps256_ps128(v_iscale)));
123 __m128i q1 = _mm_cvtps_epi32(_mm_mul_ps(n1, _mm256_castps256_ps128(v_iscale)));
124 __m128i q2 = _mm_cvtps_epi32(_mm_mul_ps(n2, _mm256_castps256_ps128(v_iscale)));
125 __m128i q3 = _mm_cvtps_epi32(_mm_mul_ps(n3, _mm256_castps256_ps128(v_iscale)));
127 __m128i q01 = _mm_packs_epi32(q0, q1);
128 __m128i q23 = _mm_packs_epi32(q2, q3);
129 __m128i q0123 = _mm_packs_epi16(q01, q23);
// Store 16 int8 quants for this group.
131 _mm_storeu_si128((__m128i *)(out_block->
qs + j), q0123);
// Horizontal sum of the 16 quants as int16 (bucket sum): pairwise lane
// adds, then a 64-bit-half swap and a 32-bit swap in the low half leave
// the total split across the two lowest 16-bit lanes.
134 __m128i p01 = _mm_add_epi16(q01, q23);
135 p01 = _mm_add_epi16(p01, _mm_shuffle_epi32(p01, _MM_SHUFFLE(1, 0, 3, 2)));
136 p01 = _mm_add_epi16(p01, _mm_shufflelo_epi16(p01, _MM_SHUFFLE(1, 0, 3, 2)));
137 int16_t bsum = (int16_t)_mm_extract_epi16(p01, 0) + (int16_t)_mm_extract_epi16(p01, 1);
// One bucket sum per 16-value group.
138 out_block->
bsums[j / 16] = bsum;
/* Quantization block structures for weight-only quantization. */
void rmsnorm_q8_k_fused(const float *input, const float *gamma, void *vy, int tokens, int d_model, int aligned_embed_dim, float eps);
static float hmax256_ps_fused(__m256 v);
static float hsum256_ps_fused(__m256 v);