#if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
#include <immintrin.h>
const int blocks_per_row = K / QK_K;

for (int row = 0; row < M; row++) {

    for (int b = 0; b < blocks_per_row; b++) {
        const block_q4_K *block = &blocks[row * blocks_per_row + b];
        /* d, dmin and the unpacked sc[]/m[] pairs come from the block header (not shown here). */

        for (int iter = 0; iter < 4; iter++) {
            const float d1 = d * (float)sc[2*iter];
            const float m1 = dmin * (float)m[2*iter];
            const float d2 = d * (float)sc[2*iter + 1];
            const float m2 = dmin * (float)m[2*iter + 1];
            const uint8_t *qs = &block->qs[iter * 32];
            const float *xp = &x[b * QK_K + iter * 64];

            /* Low nibbles cover elements 0..31 of this 64-element chunk. */
            for (int l = 0; l < 32; l++) {
                const int8_t q = (qs[l] & 0x0F);
                sum += (d1 * (float)q - m1) * xp[l];
            }

            /* High nibbles cover elements 32..63. */
            for (int l = 0; l < 32; l++) {
                const int8_t q = (qs[l] >> 4);
                sum += (d2 * (float)q - m2) * xp[l + 32];
            }
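/* Illustrative sketch (hypothetical helper, not in the original excerpt): both
 * scalar loops above apply the same Q4_K affine mapping, w = (d * sc[j]) * q - (dmin * m[j]),
 * to a 4-bit quant q in 0..15, where byte l of qs packs quant l in its low
 * nibble and quant l + 32 in its high nibble. */
static inline float q4_k_dequant_one(float d, float dmin,
                                     uint8_t sub_scale, uint8_t sub_min,
                                     uint8_t q /* 0..15 */) {
    return d * (float)sub_scale * (float)q - dmin * (float)sub_min;
}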
void gemv_q4_k_avx512(float *y, const void *W, const float *x, int M, int K) {
    const int blocks_per_row = K / QK_K;

    for (int row = 0; row < M; row++) {
        __m512 acc = _mm512_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_K *block = &blocks[row * blocks_per_row + b];

            uint8_t sc[8], m_arr[8];

            const __m512i mask_lo = _mm512_set1_epi32(0x0F);

            for (int iter = 0; iter < 4; iter++) {
                const float d1 = d * (float)sc[2*iter];
                const float m1 = dmin * (float)m_arr[2*iter];
                const float d2 = d * (float)sc[2*iter + 1];
                const float m2 = dmin * (float)m_arr[2*iter + 1];

                const __m512 vscale1 = _mm512_set1_ps(d1);
                const __m512 vmin1 = _mm512_set1_ps(m1);
                const __m512 vscale2 = _mm512_set1_ps(d2);
                const __m512 vmin2 = _mm512_set1_ps(m2);

                const uint8_t *qs = &block->qs[iter * 32];
                const float *xp = &x[b * QK_K + iter * 64];

                /* Low nibbles: two chunks of 16 weights each, paired with xp[0..31]. */
                for (int chunk = 0; chunk < 2; chunk++) {
                    __m128i packed = _mm_loadu_si128((const __m128i *)&qs[chunk * 16]);
                    __m512i bytes = _mm512_cvtepu8_epi32(packed);
                    __m512i lo = _mm512_and_epi32(bytes, mask_lo);

                    /* w = d1*q - m1, computed as fnmadd(1, m1, d1*q). */
                    __m512 w = _mm512_fnmadd_ps(_mm512_set1_ps(1.0f), vmin1,
                                                _mm512_mul_ps(_mm512_cvtepi32_ps(lo), vscale1));
                    __m512 x_vec = _mm512_loadu_ps(&xp[chunk * 16]);
                    acc = _mm512_fmadd_ps(w, x_vec, acc);
                }

                /* High nibbles: paired with xp[32..63]. */
                for (int chunk = 0; chunk < 2; chunk++) {
                    __m128i packed = _mm_loadu_si128((const __m128i *)&qs[chunk * 16]);
                    __m512i bytes = _mm512_cvtepu8_epi32(packed);
                    __m512i hi = _mm512_srli_epi32(bytes, 4);

                    __m512 w = _mm512_fnmadd_ps(_mm512_set1_ps(1.0f), vmin2,
                                                _mm512_mul_ps(_mm512_cvtepi32_ps(hi), vscale2));
                    __m512 x_vec = _mm512_loadu_ps(&xp[32 + chunk * 16]);
                    acc = _mm512_fmadd_ps(w, x_vec, acc);
                }
            }
        }

        y[row] = _mm512_reduce_add_ps(acc);
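/* Illustrative sketch (hypothetical helper, not in the original excerpt): the
 * nibble-extraction pattern used above, isolated. It loads 16 packed bytes,
 * widens them to 16 int32 lanes, keeps the low nibbles, and applies
 * w = scale*q - min. Requires AVX-512F. */
static inline void q4_k_dequant16_lo_avx512(const uint8_t *qs, float scale, float min_val, float *out) {
    const __m512i mask_lo = _mm512_set1_epi32(0x0F);
    __m128i packed = _mm_loadu_si128((const __m128i *)qs);  /* 16 packed bytes      */
    __m512i bytes  = _mm512_cvtepu8_epi32(packed);          /* widen to 16 x int32  */
    __m512i lo     = _mm512_and_epi32(bytes, mask_lo);      /* keep the low nibbles */
    __m512 w = _mm512_fnmadd_ps(_mm512_set1_ps(1.0f), _mm512_set1_ps(min_val),
                                _mm512_mul_ps(_mm512_cvtepi32_ps(lo), _mm512_set1_ps(scale)));
    _mm512_storeu_ps(out, w);                               /* w = scale*q - min    */
}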
#if defined(__AVX__) && !defined(__AVX512F__)

void gemv_q4_k_avx(float *y, const void *W, const float *x, int M, int K) {
    const int blocks_per_row = K / QK_K;

    for (int row = 0; row < M; row++) {

        __m256 acc0 = _mm256_setzero_ps();
        __m256 acc1 = _mm256_setzero_ps();
        __m256 acc2 = _mm256_setzero_ps();
        __m256 acc3 = _mm256_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_K *block = &blocks[row * blocks_per_row + b];

            uint8_t sc[8], m_arr[8];

            for (int iter = 0; iter < 4; iter++) {
                const float d1 = d * (float)sc[2*iter];
                const float m1 = dmin * (float)m_arr[2*iter];
                const float d2 = d * (float)sc[2*iter + 1];
                const float m2 = dmin * (float)m_arr[2*iter + 1];
                const uint8_t *qs = &block->qs[iter * 32];
                const float *xp = &x[b * QK_K + iter * 64];

                __m256 vd1 = _mm256_set1_ps(d1);
                __m256 vm1 = _mm256_set1_ps(m1);
                __m256 vd2 = _mm256_set1_ps(d2);
                __m256 vm2 = _mm256_set1_ps(m2);

                /* Low nibbles: dequantize 8 weights at a time into a scratch array. */
                for (int g = 0; g < 4; g++) {
                    float dq[8];  /* scratch for 8 dequantized weights */

                    for (int i = 0; i < 8; i++) {
                        dq[i] = d1 * (float)(qs[g*8 + i] & 0x0F) - m1;
                    }
                    __m256 vw = _mm256_loadu_ps(dq);
                    __m256 vx = _mm256_loadu_ps(&xp[g*8]);

                    __m256 prod = _mm256_mul_ps(vw, vx);
                    acc0 = _mm256_add_ps(acc0, prod);
                }

                /* High nibbles: same pattern, paired with xp[32..63]. */
                for (int g = 0; g < 4; g++) {
                    float dq[8];

                    for (int i = 0; i < 8; i++) {
                        dq[i] = d2 * (float)(qs[g*8 + i] >> 4) - m2;
                    }
                    __m256 vw = _mm256_loadu_ps(dq);
                    __m256 vx = _mm256_loadu_ps(&xp[32 + g*8]);

                    __m256 prod = _mm256_mul_ps(vw, vx);
                    acc1 = _mm256_add_ps(acc1, prod);
                }
            }
        }

        __m256 sum01 = _mm256_add_ps(acc0, acc1);
        __m256 sum23 = _mm256_add_ps(acc2, acc3);
        __m256 sum = _mm256_add_ps(sum01, sum23);

        /* Horizontal reduction of the 8 lanes. */
        __m128 hi = _mm256_extractf128_ps(sum, 1);
        __m128 lo = _mm256_castps256_ps128(sum);
        __m128 sum128 = _mm_add_ps(hi, lo);
        sum128 = _mm_hadd_ps(sum128, sum128);
        sum128 = _mm_hadd_ps(sum128, sum128);

        y[row] = _mm_cvtss_f32(sum128);
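/* Illustrative sketch (hypothetical helper, not in the original excerpt): the
 * horizontal reduction above, factored into a reusable function.
 * Requires AVX (and SSE3 for _mm_hadd_ps). */
static inline float q4_k_hsum_avx(__m256 v) {
    __m128 hi = _mm256_extractf128_ps(v, 1);   /* upper 4 lanes */
    __m128 lo = _mm256_castps256_ps128(v);     /* lower 4 lanes */
    __m128 s  = _mm_add_ps(hi, lo);
    s = _mm_hadd_ps(s, s);                     /* pairwise adds */
    s = _mm_hadd_ps(s, s);
    return _mm_cvtss_f32(s);
}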
    gemv_q4_k_avx512(y, W, x, M, K);
#elif defined(__AVX__)
    gemv_q4_k_avx(y, W, x, M, K);

    for (int n = 0; n < N; n++) {
        gemv_q4_k(&Y[n * M], W, &X[n * K], M, K);
void gemm_q4_k_avx512(float *Y, const void *W, const float *X, int M, int N, int K) {
    const int blocks_per_row = K / QK_K;

    const int N4 = N / 4 * 4;

    for (int row = 0; row < M; row++) {

        for (int n = 0; n < N4; n += 4) {
            __m512 acc0 = _mm512_setzero_ps();
            __m512 acc1 = _mm512_setzero_ps();
            __m512 acc2 = _mm512_setzero_ps();
            __m512 acc3 = _mm512_setzero_ps();

            for (int b = 0; b < blocks_per_row; b++) {
                const block_q4_K *block = &blocks[row * blocks_per_row + b];

                uint8_t sc[8], m_arr[8];

                for (int sub = 0; sub < 8; sub++) {
                    const float scale = d * (float)sc[sub];
                    const float min_val = dmin * (float)m_arr[sub];
                    const __m512 vscale = _mm512_set1_ps(scale);
                    const __m512 vmin = _mm512_set1_ps(min_val);
                    const __m512i offset = _mm512_set1_epi32(8);
                    const __m512i mask_lo = _mm512_set1_epi32(0x0F);

                    const uint8_t *qs = &block->qs[sub * 16];
                    const int x_offset = b * QK_K + sub * 32;

                    __m128i packed = _mm_loadu_si128((const __m128i *)qs);
                    __m512i bytes = _mm512_cvtepu8_epi32(packed);

                    __m512i lo = _mm512_sub_epi32(_mm512_and_epi32(bytes, mask_lo), offset);
                    __m512i hi = _mm512_sub_epi32(_mm512_srli_epi32(bytes, 4), offset);

                    __m512 w_lo = _mm512_fmadd_ps(_mm512_cvtepi32_ps(lo), vscale, vmin);
                    __m512 w_hi = _mm512_fmadd_ps(_mm512_cvtepi32_ps(hi), vscale, vmin);

                    /* Reuse the dequantized weights for 4 input rows (output columns n..n+3). */
                    for (int bn = 0; bn < 4; bn++) {
                        const float *xp = &X[(n + bn) * K + x_offset];

                        /* Low nibbles pair with even x positions, high nibbles with odd ones. */
                        __m512 x_even = _mm512_set_ps(
                            xp[30], xp[28], xp[26], xp[24], xp[22], xp[20], xp[18], xp[16],
                            xp[14], xp[12], xp[10], xp[8], xp[6], xp[4], xp[2], xp[0]);
                        __m512 x_odd = _mm512_set_ps(
                            xp[31], xp[29], xp[27], xp[25], xp[23], xp[21], xp[19], xp[17],
                            xp[15], xp[13], xp[11], xp[9], xp[7], xp[5], xp[3], xp[1]);

                        __m512 *acc = (bn == 0) ? &acc0 : (bn == 1) ? &acc1 :
                                      (bn == 2) ? &acc2 : &acc3;
                        *acc = _mm512_fmadd_ps(w_lo, x_even, *acc);
                        *acc = _mm512_fmadd_ps(w_hi, x_odd, *acc);
                    }
                }
            }

            Y[(n + 0) * M + row] = _mm512_reduce_add_ps(acc0);
            Y[(n + 1) * M + row] = _mm512_reduce_add_ps(acc1);
            Y[(n + 2) * M + row] = _mm512_reduce_add_ps(acc2);
            Y[(n + 3) * M + row] = _mm512_reduce_add_ps(acc3);
        }

        /* Tail: handle the remaining columns one at a time. */
        for (int n = N4; n < N; n++) {
            __m512 acc = _mm512_setzero_ps();

            for (int b = 0; b < blocks_per_row; b++) {
                const block_q4_K *block = &blocks[row * blocks_per_row + b];

                uint8_t sc[8], m_arr[8];

                for (int sub = 0; sub < 8; sub++) {
                    const float scale = d * (float)sc[sub];
                    const float min_val = dmin * (float)m_arr[sub];
                    const __m512 vscale = _mm512_set1_ps(scale);
                    const __m512 vmin = _mm512_set1_ps(min_val);
                    const __m512i offset = _mm512_set1_epi32(8);
                    const __m512i mask_lo = _mm512_set1_epi32(0x0F);

                    const uint8_t *qs = &block->qs[sub * 16];
                    const float *xp = &X[n * K + b * QK_K + sub * 32];

                    __m128i packed = _mm_loadu_si128((const __m128i *)qs);
                    __m512i bytes = _mm512_cvtepu8_epi32(packed);

                    __m512i lo = _mm512_sub_epi32(_mm512_and_epi32(bytes, mask_lo), offset);
                    __m512i hi = _mm512_sub_epi32(_mm512_srli_epi32(bytes, 4), offset);

                    __m512 w_lo = _mm512_fmadd_ps(_mm512_cvtepi32_ps(lo), vscale, vmin);
                    __m512 w_hi = _mm512_fmadd_ps(_mm512_cvtepi32_ps(hi), vscale, vmin);

                    __m512 x_even = _mm512_set_ps(
                        xp[30], xp[28], xp[26], xp[24], xp[22], xp[20], xp[18], xp[16],
                        xp[14], xp[12], xp[10], xp[8], xp[6], xp[4], xp[2], xp[0]);
                    __m512 x_odd = _mm512_set_ps(
                        xp[31], xp[29], xp[27], xp[25], xp[23], xp[21], xp[19], xp[17],
                        xp[15], xp[13], xp[11], xp[9], xp[7], xp[5], xp[3], xp[1]);

                    acc = _mm512_fmadd_ps(w_lo, x_even, acc);
                    acc = _mm512_fmadd_ps(w_hi, x_odd, acc);
                }
            }

            Y[n * M + row] = _mm512_reduce_add_ps(acc);
float dot_q4_k(const void *w_q4k, const float *x, int K)
const int blocks_per_row = K / QK_K;

memset(dX, 0, K * sizeof(float));

for (int row = 0; row < M; row++) {
    const float dy = dY[row];

    for (int b = 0; b < blocks_per_row; b++) {
        const block_q4_K *block = &blocks[row * blocks_per_row + b];

        for (int iter = 0; iter < 4; iter++) {
            const float d1 = d * (float)sc[2 * iter];
            const float m1 = dmin * (float)m[2 * iter];
            const float d2 = d * (float)sc[2 * iter + 1];
            const float m2 = dmin * (float)m[2 * iter + 1];

            const uint8_t *qs = &block->qs[iter * 32];
            float *dxp = &dX[b * QK_K + iter * 64];

            /* Each dequantized weight scatters dy into the input gradient. */
            for (int l = 0; l < 32; l++) {
                const int q = (qs[l] & 0x0F);
                const float w = d1 * (float)q - m1;
                dxp[l] += w * dy;
            }

            for (int l = 0; l < 32; l++) {
                const int q = (qs[l] >> 4);
                const float w = d2 * (float)q - m2;
                dxp[32 + l] += w * dy;
            }
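/* Illustrative sketch (not in the original excerpt): because y = W x is linear
 * in x, the input gradient computed above is dX = W^T dY, so for any vector v
 * the adjoint identity <dY, W v> == <W^T dY, v> must hold up to float rounding.
 * A small consistency check built only on the scalar reference routines
 * (gemv_q4_k_ref, gemv_q4_k_backward_ref); W must already point at
 * M * (K / QK_K) valid block_q4_K blocks. Needs <stdlib.h> and <math.h>. */
static int q4_k_check_backward(const void *W, int M, int K) {
    float *v  = malloc((size_t)K * sizeof(float));
    float *dY = malloc((size_t)M * sizeof(float));
    float *y  = malloc((size_t)M * sizeof(float));
    float *dX = malloc((size_t)K * sizeof(float));
    for (int k = 0; k < K; k++) v[k]  = (float)rand() / (float)RAND_MAX - 0.5f;
    for (int r = 0; r < M; r++) dY[r] = (float)rand() / (float)RAND_MAX - 0.5f;

    gemv_q4_k_ref(y, W, v, M, K);             /* y  = W v     */
    gemv_q4_k_backward_ref(dX, W, dY, M, K);  /* dX = W^T dY  */

    double lhs = 0.0, rhs = 0.0;
    for (int r = 0; r < M; r++) lhs += (double)dY[r] * (double)y[r];
    for (int k = 0; k < K; k++) rhs += (double)dX[k] * (double)v[k];

    int ok = fabs(lhs - rhs) <= 1e-3 * (fabs(lhs) + 1.0);
    free(v); free(dY); free(y); free(dX);
    return ok;
}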
void gemv_q4_k_backward_avx512(float *dX, const void *W, const float *dY, int M, int K) {
    const int blocks_per_row = K / QK_K;
    const __m512i mask_lo = _mm512_set1_epi32(0x0F);

    memset(dX, 0, K * sizeof(float));

    for (int row = 0; row < M; row++) {
        const __m512 vdy = _mm512_set1_ps(dY[row]);

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_K *block = &blocks[row * blocks_per_row + b];

            uint8_t sc[8], m_arr[8];

            for (int iter = 0; iter < 4; iter++) {
                const float d1 = d * (float)sc[2 * iter];
                const float m1 = dmin * (float)m_arr[2 * iter];
                const float d2 = d * (float)sc[2 * iter + 1];
                const float m2 = dmin * (float)m_arr[2 * iter + 1];

                const __m512 vd1 = _mm512_set1_ps(d1);
                const __m512 vm1 = _mm512_set1_ps(m1);
                const __m512 vd2 = _mm512_set1_ps(d2);
                const __m512 vm2 = _mm512_set1_ps(m2);

                const uint8_t *qs = &block->qs[iter * 32];
                float *dxp = &dX[b * QK_K + iter * 64];

                /* Low nibbles: accumulate w * dY[row] into the first 32 gradients of this chunk. */
                for (int chunk = 0; chunk < 2; chunk++) {
                    __m128i packed = _mm_loadu_si128((const __m128i *)&qs[chunk * 16]);
                    __m512i bytes = _mm512_cvtepu8_epi32(packed);
                    __m512i lo = _mm512_and_epi32(bytes, mask_lo);

                    __m512 w = _mm512_fnmadd_ps(_mm512_set1_ps(1.0f), vm1,
                                                _mm512_mul_ps(_mm512_cvtepi32_ps(lo), vd1));
                    __m512 grad = _mm512_mul_ps(w, vdy);
                    __m512 existing = _mm512_loadu_ps(&dxp[chunk * 16]);
                    _mm512_storeu_ps(&dxp[chunk * 16], _mm512_add_ps(existing, grad));
                }

                /* High nibbles: accumulate into the second 32 gradients of this chunk. */
                for (int chunk = 0; chunk < 2; chunk++) {
                    __m128i packed = _mm_loadu_si128((const __m128i *)&qs[chunk * 16]);
                    __m512i bytes = _mm512_cvtepu8_epi32(packed);
                    __m512i hi = _mm512_srli_epi32(bytes, 4);

                    __m512 w = _mm512_fnmadd_ps(_mm512_set1_ps(1.0f), vm2,
                                                _mm512_mul_ps(_mm512_cvtepi32_ps(hi), vd2));
                    __m512 grad = _mm512_mul_ps(w, vdy);
                    __m512 existing = _mm512_loadu_ps(&dxp[32 + chunk * 16]);
                    _mm512_storeu_ps(&dxp[32 + chunk * 16], _mm512_add_ps(existing, grad));
    gemv_q4_k_backward_avx512(dX, W, dY, M, K);

    for (int n = 0; n < N; n++) {
    if (!A || !B || !C) {

    if (M <= 0 || N <= 0 || K <= 0) {

    for (int i = 0; i < M; ++i) {
        float *row = C + (size_t)i * (size_t)N;
        for (int j = 0; j < N; ++j) {
Quantization block structures for weight-only quantization.
#define GGML_FP16_TO_FP32
#define CK_FP16_TO_FP32(x)
static void unpack_q4_k_scales(const uint8_t *scales, uint8_t *sc, uint8_t *m)
Unpack Q4_K sub-block scales and mins.
void gemm_q4_k_ref(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q4_K weights (scalar reference).
float dot_q4_k(const void *w_q4k, const float *x, int K)
Compute dot product of Q4_K row with FP32 vector.
void gemm_q4_k_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemv_q4_k_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward pass.
void gemv_q4_k_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_K weights (scalar reference).
void gemv_q4_k_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient (scalar reference).
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
void gemm_q4_k(float *Y, const void *W, const float *X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.
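A minimal usage sketch of the dispatch entry points, assuming w_q4k already holds M * (K / QK_K) block_q4_K blocks quantized elsewhere (quantization is outside this excerpt); the output layout follows the batched loop shown earlier, Y[n * M + row]:

void q4_k_forward(const void *w_q4k, const float *X, float *Y, int M, int N, int K) {
    if (N == 1) {
        gemv_q4_k(Y, w_q4k, X, M, K);     /* y = W x; y has length M */
    } else {
        gemm_q4_k(Y, w_q4k, X, M, N, K);  /* Y[n * M + row] = dot of W row with input row n */
    }
}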