36 #include <immintrin.h>
58 const int blocks_per_row = K /
QK5_1;
60 for (
int row = 0; row < M; row++) {
63 for (
int b = 0; b < blocks_per_row; b++) {
64 const block_q5_1 *block = &blocks[row * blocks_per_row + b];
67 const float *xp = &x[b *
QK5_1];
71 memcpy(&qh, block->
qh,
sizeof(qh));
78 for (
int j = 0; j <
QK5_1 / 2; j++) {
79 const int lo = (block->
qs[j] & 0x0F);
80 const int hi = ((qh >> j) & 1) << 4;
81 const float w = d * (float)(lo | hi) + m;
86 for (
int j = 0; j <
QK5_1 / 2; j++) {
87 const int lo = (block->
qs[j] >> 4);
88 const int hi = ((qh >> (j + 16)) & 1) << 4;
89 const float w = d * (float)(lo | hi) + m;
90 sum += w * xp[j +
QK5_1 / 2];
106 void gemv_q5_1_avx512(
float *y,
112 const int blocks_per_row = K /
QK5_1;
113 const __m512i mask_lo = _mm512_set1_epi32(0x0F);
115 for (
int row = 0; row < M; row++) {
116 __m512 acc = _mm512_setzero_ps();
118 for (
int b = 0; b < blocks_per_row; b++) {
119 const block_q5_1 *block = &blocks[row * blocks_per_row + b];
122 const float *xp = &x[b *
QK5_1];
126 memcpy(&qh, block->
qh,
sizeof(qh));
129 __m128i packed = _mm_loadu_si128((
const __m128i *)block->
qs);
130 __m512i bytes = _mm512_cvtepu8_epi32(packed);
133 __m512i lo = _mm512_and_epi32(bytes, mask_lo);
134 __m512i hi_shift = _mm512_srli_epi32(bytes, 4);
137 __m512i qh_first = _mm512_set_epi32(
138 ((qh >> 15) & 1) << 4, ((qh >> 14) & 1) << 4,
139 ((qh >> 13) & 1) << 4, ((qh >> 12) & 1) << 4,
140 ((qh >> 11) & 1) << 4, ((qh >> 10) & 1) << 4,
141 ((qh >> 9) & 1) << 4, ((qh >> 8) & 1) << 4,
142 ((qh >> 7) & 1) << 4, ((qh >> 6) & 1) << 4,
143 ((qh >> 5) & 1) << 4, ((qh >> 4) & 1) << 4,
144 ((qh >> 3) & 1) << 4, ((qh >> 2) & 1) << 4,
145 ((qh >> 1) & 1) << 4, ((qh >> 0) & 1) << 4
149 __m512i qh_second = _mm512_set_epi32(
150 ((qh >> 31) & 1) << 4, ((qh >> 30) & 1) << 4,
151 ((qh >> 29) & 1) << 4, ((qh >> 28) & 1) << 4,
152 ((qh >> 27) & 1) << 4, ((qh >> 26) & 1) << 4,
153 ((qh >> 25) & 1) << 4, ((qh >> 24) & 1) << 4,
154 ((qh >> 23) & 1) << 4, ((qh >> 22) & 1) << 4,
155 ((qh >> 21) & 1) << 4, ((qh >> 20) & 1) << 4,
156 ((qh >> 19) & 1) << 4, ((qh >> 18) & 1) << 4,
157 ((qh >> 17) & 1) << 4, ((qh >> 16) & 1) << 4
161 __m512i q_first = _mm512_or_epi32(lo, qh_first);
162 __m512i q_second = _mm512_or_epi32(hi_shift, qh_second);
165 __m512 w_first = _mm512_fmadd_ps(_mm512_cvtepi32_ps(q_first), vscale, vmin);
166 __m512 w_second = _mm512_fmadd_ps(_mm512_cvtepi32_ps(q_second), vscale, vmin);
169 __m512 x_first = _mm512_loadu_ps(&xp[0]);
170 __m512 x_second = _mm512_loadu_ps(&xp[16]);
172 acc = _mm512_fmadd_ps(w_first, x_first, acc);
173 acc = _mm512_fmadd_ps(w_second, x_second, acc);
176 y[row] = _mm512_reduce_add_ps(acc);
/*
 * Auto-dispatch GEMV: picks the AVX-512 kernel when available, otherwise
 * the scalar reference. Same contract as gemv_q5_1_ref.
 *
 * NOTE(review): only the AVX-512 call survived the extraction; the dispatch
 * condition and the fallback branch are reconstructed guesses (a
 * compile-time __AVX512F__ gate). The original may have used a runtime CPUID
 * check instead — confirm against upstream before relying on this.
 */
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K) {
#if defined(__AVX512F__)
    gemv_q5_1_avx512(y, W, x, M, K);
#else
    gemv_q5_1_ref(y, W, x, M, K);
#endif
}
/*
 * Matrix-matrix multiply with Q5_1 weights: Y = W @ X, computed as N
 * independent GEMVs (one per input column). X holds N columns of length K;
 * Y holds N columns of length M.
 *
 * NOTE(review): signature reconstructed from the file's API index; only the
 * loop body survived the extraction — confirm against upstream.
 */
void gemm_q5_1(float *Y, const void *W, const float *X, int M, int N, int K) {
    for (int n = 0; n < N; n++) {
        gemv_q5_1(&Y[n * M], W, &X[n * K], M, K);
    }
}
232 const int blocks_per_row = K /
QK5_1;
235 memset(dX, 0, K *
sizeof(
float));
238 for (
int row = 0; row < M; row++) {
239 const float dy = dY[row];
241 for (
int b = 0; b < blocks_per_row; b++) {
242 const block_q5_1 *block = &blocks[row * blocks_per_row + b];
245 float *dxp = &dX[b *
QK5_1];
249 memcpy(&qh, block->
qh,
sizeof(qh));
252 for (
int j = 0; j <
QK5_1 / 2; j++) {
253 const int lo = (block->
qs[j] & 0x0F);
254 const int hi = ((qh >> j) & 1) << 4;
255 const float w = d * (float)(lo | hi) + m;
260 for (
int j = 0; j <
QK5_1 / 2; j++) {
261 const int lo = (block->
qs[j] >> 4);
262 const int hi = ((qh >> (j + 16)) & 1) << 4;
263 const float w = d * (float)(lo | hi) + m;
264 dxp[j +
QK5_1 / 2] += w * dy;
/*
 * Batched backward pass: per-column input gradients for a batch of N
 * upstream gradients. Each column n gets dX[n*K .. n*K+K-1] from
 * dY[n*M .. n*M+M-1]; columns are independent.
 *
 * NOTE(review): only the outer loop header survived the extraction. The
 * signature comes from the file's API index and the inner call is inferred
 * by analogy with gemm_q5_1's per-column GEMV dispatch — confirm against
 * upstream.
 */
void gemm_q5_1_backward(float *dX, const void *W, const float *dY, int M,
                        int N, int K) {
    for (int n = 0; n < N; n++) {
        gemv_q5_1_backward(&dX[n * K], W, &dY[n * M], M, K);
    }
}
316 const int blocks_per_row = K /
QK5_1;
318 for (
int m = 0; m < M; m++) {
319 const float *a_row = &A[m * K];
321 for (
int n = 0; n < N; n++) {
324 for (
int b = 0; b < blocks_per_row; b++) {
325 const block_q5_1 *block = &blocks[n * blocks_per_row + b];
328 const float *ap = &a_row[b *
QK5_1];
331 memcpy(&qh, block->
qh,
sizeof(qh));
334 for (
int j = 0; j <
QK5_1 / 2; j++) {
335 const int lo = (block->
qs[j] & 0x0F);
336 const int hi = ((qh >> j) & 1) << 4;
337 sum += (d * (float)(lo | hi) + min) * ap[j];
341 for (
int j = 0; j <
QK5_1 / 2; j++) {
342 const int lo = (block->
qs[j] >> 4);
343 const int hi = ((qh >> (j + 16)) & 1) << 4;
344 sum += (d * (float)(lo | hi) + min) * ap[j +
QK5_1 / 2];
348 C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
357 float dot_q5_1(
const void *w_q5_1,
const float *x,
int K)
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
void gemm_q5_1_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
float dot_q5_1(const void *w_q5_1, const float *x, int K)
void gemv_q5_1(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
void gemv_q5_1_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q5_1 weights (scalar reference).
void gemm_nt_q5_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q5_1 weights: C = A @ B^T.
void gemv_q5_1_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.
void gemv_q5_1_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
void gemm_q5_1(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q5_1 weights.