34 #include <immintrin.h>
// NOTE(review): scalar Q4_1 GEMV inner loops (gemv_q4_1_ref). This is an
// elided extract — the function signature, the per-block decode of scale
// `d` / offset `m`, the `sum` accumulator init, and the final store into
// y[row] are not visible here; do not assume their exact form.
56 const int blocks_per_row = K /
QK4_1;
// Iterate output rows; each row of W is stored as blocks_per_row Q4_1 blocks.
58 for (
int row = 0; row < M; row++) {
61 for (
int b = 0; b < blocks_per_row; b++) {
62 const block_q4_1 *block = &blocks[row * blocks_per_row + b];
// xp points at the QK4_1 input values this block multiplies against.
65 const float *xp = &x[b *
QK4_1];
// Each byte packs two 4-bit quants: low nibble -> even element,
// high nibble -> odd element (see xp[2*i + 0] / xp[2*i + 1] below).
67 for (
int i = 0; i <
QK4_1 / 2; i++) {
68 const uint8_t packed = block->
qs[i];
69 const int q0 = (packed & 0x0F);
70 const int q1 = (packed >> 4);
// Dequantize with the affine Q4_1 scheme: w = d * q + m.
73 const float w0 = d * (float)q0 + m;
74 const float w1 = d * (float)q1 + m;
76 sum += w0 * xp[2*i + 0];
77 sum += w1 * xp[2*i + 1];
// NOTE(review): AVX-512 Q4_1 GEMV kernel. Elided extract — the remaining
// parameters, the per-block `vscale`/`vmin` setup, and the loop-closing
// braces are not visible here.
89 void gemv_q4_1_avx512(
float *y,
95 const int blocks_per_row = K /
QK4_1;
// Mask that isolates the low nibble in each 32-bit lane.
96 const __m512i mask_lo = _mm512_set1_epi32(0x0F);
98 for (
int row = 0; row < M; row++) {
// One 16-lane float accumulator per output row.
99 __m512 acc = _mm512_setzero_ps();
101 for (
int b = 0; b < blocks_per_row; b++) {
102 const block_q4_1 *block = &blocks[row * blocks_per_row + b];
105 const float *xp = &x[b *
QK4_1];
// Load the block's 16 packed bytes, then widen each byte to a 32-bit lane.
108 __m128i packed = _mm_loadu_si128((
const __m128i *)block->
qs);
109 __m512i bytes = _mm512_cvtepu8_epi32(packed);
// Split nibbles: lo = even-position quants, hi = odd-position quants
// (matches the scalar reference's xp[2*i] / xp[2*i + 1] pairing).
112 __m512i lo = _mm512_and_epi32(bytes, mask_lo);
113 __m512i hi = _mm512_srli_epi32(bytes, 4);
// Dequantize in-register: w = q * scale + min.
// NOTE(review): vscale/vmin are defined on elided lines — presumably
// broadcast from this block's d/m; verify against the full source.
116 __m512 w_lo = _mm512_fmadd_ps(_mm512_cvtepi32_ps(lo), vscale, vmin);
117 __m512 w_hi = _mm512_fmadd_ps(_mm512_cvtepi32_ps(hi), vscale, vmin);
// Gather the even- and odd-indexed inputs. _mm512_set_ps takes its
// arguments highest-lane-first, hence the descending index order.
120 __m512 x_even = _mm512_set_ps(
121 xp[30], xp[28], xp[26], xp[24], xp[22], xp[20], xp[18], xp[16],
122 xp[14], xp[12], xp[10], xp[8], xp[6], xp[4], xp[2], xp[0]);
123 __m512 x_odd = _mm512_set_ps(
124 xp[31], xp[29], xp[27], xp[25], xp[23], xp[21], xp[19], xp[17],
125 xp[15], xp[13], xp[11], xp[9], xp[7], xp[5], xp[3], xp[1]);
127 acc = _mm512_fmadd_ps(w_lo, x_even, acc);
128 acc = _mm512_fmadd_ps(w_hi, x_odd, acc);
// Horizontal sum of the 16 lanes yields the row's dot product.
131 y[row] = _mm512_reduce_add_ps(acc);
// Fragment of the auto-dispatch GEMV wrapper: route to the AVX-512 kernel
// (the dispatch condition is on elided lines).
145 gemv_q4_1_avx512(y, W, x, M, K);
// gemm_q4_1 fragment: treat each of the N input columns as an independent
// GEMV against the shared quantized weights (surrounding body elided).
163 for (
int n = 0; n < N; n++) {
164 gemv_q4_1(&Y[n * M], W, &X[n * K], M, K);
// NOTE(review): gemv_q4_1_backward_ref interior — computes the input
// gradient dX = W^T * dY by dequantizing each Q4_1 block and scattering
// its contribution. Elided extract: the signature and the per-block
// decode of `d`/`m` are not visible here.
187 const int blocks_per_row = K /
QK4_1;
// Gradients accumulate across all M rows, so clear dX first.
190 memset(dX, 0, K *
sizeof(
float));
193 for (
int row = 0; row < M; row++) {
194 const float dy = dY[row];
196 for (
int b = 0; b < blocks_per_row; b++) {
197 const block_q4_1 *block = &blocks[row * blocks_per_row + b];
200 float *dxp = &dX[b *
QK4_1];
// Same nibble layout as the forward pass: low nibble -> even element,
// high nibble -> odd element.
202 for (
int i = 0; i <
QK4_1 / 2; i++) {
203 const uint8_t packed = block->
qs[i];
204 const int q0 = (packed & 0x0F);
205 const int q1 = (packed >> 4);
207 const float w0 = d * (float)q0 + m;
208 const float w1 = d * (float)q1 + m;
// Chain rule: dX[k] += W[row][k] * dY[row].
210 dxp[2*i + 0] += w0 * dy;
211 dxp[2*i + 1] += w1 * dy;
// gemm_q4_1_backward fragment: loop over the N batch columns (loop body
// and surrounding function are on elided lines).
236 for (
int n = 0; n < N; n++) {
// NOTE(review): gemm_nt_q4_1 interior — C = A @ B^T where B is stored as
// Q4_1 blocks, row-major by output column n. Elided extract: the
// signature, the `sum` accumulator init, and the per-block decode of
// `d`/`min` are not visible here.
263 const int blocks_per_row = K /
QK4_1;
265 for (
int m = 0; m < M; m++) {
266 const float *a_row = &A[m * K];
268 for (
int n = 0; n < N; n++) {
// Dot product of A's row m with the dequantized row n of B.
271 for (
int b = 0; b < blocks_per_row; b++) {
272 const block_q4_1 *block = &blocks[n * blocks_per_row + b];
275 const float *ap = &a_row[b *
QK4_1];
277 for (
int i = 0; i <
QK4_1 / 2; i++) {
278 const uint8_t packed = block->
qs[i];
279 const int q0 = (packed & 0x0F);
280 const int q1 = (packed >> 4);
// Dequantize: w = d * q + min (this function names the affine offset
// `min`, where the GEMV kernels call it `m`).
282 const float w0 = d * (float)q0 + min;
283 const float w1 = d * (float)q1 + min;
285 sum += w0 * ap[2 * i + 0];
286 sum += w1 * ap[2 * i + 1];
// Optional per-column bias add on the output element.
290 C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
299 float dot_q4_1(
const void *w_q4_1,
const float *x,
int K)
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
void gemm_q4_1_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
void gemv_q4_1_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.
void gemv_q4_1_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
void gemv_q4_1_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_1 weights (scalar reference).
void gemm_nt_q4_1(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
void gemv_q4_1(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
void gemm_q4_1(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q4_1 weights.
float dot_q4_1(const void *w_q4_1, const float *x, int K)