#include <immintrin.h>
#include <stdint.h>
#include <string.h>
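The excerpts below rely on a block type and a scale-conversion macro that the index at the end of this page attributes to this file but whose definitions are not shown. A minimal sketch, so the code reads stand-alone (the field names and the fp16 storage are assumptions inferred from how the kernels use them):

/* Assumed definitions: the real struct and macro live elsewhere in the file. */
#define QK4_0 32                    /* weights per quantization block */

typedef struct {
    uint16_t d;                     /* per-block scale, stored as IEEE fp16 (assumed) */
    uint8_t  qs[QK4_0 / 2];         /* 32 signed 4-bit weights, two per byte */
} block_q4_0;

/* Assumed fp16 -> fp32 helper; _cvtsh_ss requires F16C. */
#define CK_FP16_TO_FP32(x) _cvtsh_ss(x)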
void gemv_q4_0_ref(float *y, const void *W, const float *x, int M, int K) {
    const block_q4_0 *blocks = (const block_q4_0 *)W;
    const int blocks_per_row = K / QK4_0;

    for (int row = 0; row < M; row++) {
        float sum = 0.0f;

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_0 *block = &blocks[row * blocks_per_row + b];
            const float d = CK_FP16_TO_FP32(block->d);   /* per-block scale */
            const float *xp = &x[b * QK4_0];

            /* Each byte packs two 4-bit weights, stored with a +8 bias. */
            for (int i = 0; i < QK4_0 / 2; i++) {
                const uint8_t packed = block->qs[i];
                const int8_t q0 = (packed & 0x0F) - 8;
                const int8_t q1 = (packed >> 4) - 8;

                sum += d * (float)q0 * xp[2*i + 0];
                sum += d * (float)q1 * xp[2*i + 1];
            }
        }

        y[row] = sum;
    }
}
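As a worked example of the decode: a packed byte of 0x2B has low nibble 0xB and high nibble 0x2, which decode to 11 - 8 = 3 and 2 - 8 = -6. One byte therefore yields the weight pair (3, -6), which multiplies the adjacent activations xp[2*i] and xp[2*i+1] after scaling by d.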
void gemv_q4_0_avx512(float *y, const void *W, const float *x, int M, int K) {
    const block_q4_0 *blocks = (const block_q4_0 *)W;
    const int blocks_per_row = K / QK4_0;
    const __m512i offset  = _mm512_set1_epi32(8);
    const __m512i mask_lo = _mm512_set1_epi32(0x0F);

    for (int row = 0; row < M; row++) {
        __m512 acc = _mm512_setzero_ps();

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_0 *block = &blocks[row * blocks_per_row + b];
            const __m512 vscale = _mm512_set1_ps(CK_FP16_TO_FP32(block->d));
            const float *xp = &x[b * QK4_0];

            /* Load all 16 packed bytes and widen each to a 32-bit lane. */
            __m128i packed = _mm_loadu_si128((const __m128i *)block->qs);
            __m512i bytes  = _mm512_cvtepu8_epi32(packed);

            /* Split nibbles and recenter to [-8, 7]. */
            __m512i lo = _mm512_sub_epi32(_mm512_and_epi32(bytes, mask_lo), offset);
            __m512i hi = _mm512_sub_epi32(_mm512_srli_epi32(bytes, 4), offset);

            __m512 w_lo = _mm512_mul_ps(_mm512_cvtepi32_ps(lo), vscale);
            __m512 w_hi = _mm512_mul_ps(_mm512_cvtepi32_ps(hi), vscale);

            /* Gather the even- and odd-indexed activations to match the
               interleaved nibble layout. */
            __m512 x_even = _mm512_set_ps(
                xp[30], xp[28], xp[26], xp[24], xp[22], xp[20], xp[18], xp[16],
                xp[14], xp[12], xp[10], xp[8],  xp[6],  xp[4],  xp[2],  xp[0]);
            __m512 x_odd = _mm512_set_ps(
                xp[31], xp[29], xp[27], xp[25], xp[23], xp[21], xp[19], xp[17],
                xp[15], xp[13], xp[11], xp[9],  xp[7],  xp[5],  xp[3],  xp[1]);

            acc = _mm512_fmadd_ps(w_lo, x_even, acc);
            acc = _mm512_fmadd_ps(w_hi, x_odd, acc);
        }

        y[row] = _mm512_reduce_add_ps(acc);
    }
}
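The _mm512_set_ps gathers above compile to scalar loads and inserts. The same even/odd split can be done with two contiguous vector loads and a two-source lane permute; a sketch under the same interleaved layout (a drop-in for the x_even/x_odd construction, not from the original source):

/* Deinterleave xp[0..31] with contiguous loads plus a cross-lane
   permute instead of 32 scalar inserts. */
const __m512i idx_even = _mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14,
                                           16, 18, 20, 22, 24, 26, 28, 30);
const __m512i idx_odd  = _mm512_setr_epi32(1, 3, 5, 7, 9, 11, 13, 15,
                                           17, 19, 21, 23, 25, 27, 29, 31);
__m512 x0 = _mm512_loadu_ps(xp);        /* xp[0..15]  */
__m512 x1 = _mm512_loadu_ps(xp + 16);   /* xp[16..31] */
__m512 x_even = _mm512_permutex2var_ps(x0, idx_even, x1);
__m512 x_odd  = _mm512_permutex2var_ps(x0, idx_odd,  x1);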
void gemv_q4_0(float *y, const void *W, const float *x, int M, int K) {
    /* Reconstructed dispatch: a compile-time ISA guard is shown here;
       the original may gate this differently (a runtime variant is
       sketched below). */
#if defined(__AVX512F__)
    gemv_q4_0_avx512(y, W, x, M, K);
#else
    gemv_q4_0_ref(y, W, x, M, K);
#endif
}
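When one binary must serve machines with and without AVX-512, the guard can move to runtime. A sketch using GCC/Clang's __builtin_cpu_supports (gemv_q4_0_runtime is a hypothetical name; the AVX-512 kernel would then need __attribute__((target("avx512f"))) or a separately compiled translation unit):

/* Hypothetical runtime-dispatch variant; a production version would
   cache the CPU probe instead of repeating it per call. */
void gemv_q4_0_runtime(float *y, const void *W, const float *x, int M, int K) {
    if (__builtin_cpu_supports("avx512f")) {
        gemv_q4_0_avx512(y, W, x, M, K);
    } else {
        gemv_q4_0_ref(y, W, x, M, K);
    }
}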
void gemm_q4_0(float *Y, const void *W, const float *X, int M, int N, int K) {
    /* Batched forward: one GEMV per input column. */
    for (int n = 0; n < N; n++) {
        gemv_q4_0(&Y[n * M], W, &X[n * K], M, K);
    }
}
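The N GEMV calls touch disjoint slices of X and Y, so the batched forward parallelizes trivially. A sketch with OpenMP (hypothetical variant, compile with -fopenmp):

/* Hypothetical parallel variant: each column is independent. */
void gemm_q4_0_parallel(float *Y, const void *W, const float *X, int M, int N, int K) {
    #pragma omp parallel for
    for (int n = 0; n < N; n++) {
        gemv_q4_0(&Y[n * M], W, &X[n * K], M, K);
    }
}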
void gemm_nt_q4_0(const float *A, const void *B, const float *bias,
                  float *C, int M, int N, int K) {
    const block_q4_0 *blocks = (const block_q4_0 *)B;
    const int blocks_per_row = K / QK4_0;

    for (int m = 0; m < M; m++) {
        const float *a_row = &A[m * K];

        for (int n = 0; n < N; n++) {
            float sum = 0.0f;

            for (int b = 0; b < blocks_per_row; b++) {
                const block_q4_0 *block = &blocks[n * blocks_per_row + b];
                const float d = CK_FP16_TO_FP32(block->d);   /* per-block scale */
                const float *ap = &a_row[b * QK4_0];

                for (int i = 0; i < QK4_0 / 2; i++) {
                    const uint8_t packed = block->qs[i];
                    const int q0 = (packed & 0x0F) - 8;
                    const int q1 = (packed >> 4) - 8;

                    sum += d * (float)q0 * ap[2 * i + 0];
                    sum += d * (float)q1 * ap[2 * i + 1];
                }
            }

            C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
        }
    }
}
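Nothing in the excerpts shows how the 4-bit blocks are produced. To exercise the kernels end to end, a row encoder along these lines would do (a sketch, not the project's actual encoder: CK_FP32_TO_FP16 is a hypothetical inverse of the conversion macro, the rounding follows the usual Q4_0 recipe of mapping the largest-magnitude value to -8, and the nibble order matches the interleaved decode above; needs <math.h>):

/* Hypothetical Q4_0 row encoder for testing. */
void quantize_row_q4_0_sketch(const float *x, block_q4_0 *out, int K) {
    for (int b = 0; b < K / QK4_0; b++) {
        const float *xp = &x[b * QK4_0];

        /* Find the value with the largest magnitude in the block. */
        float amax = 0.0f, vmax = 0.0f;
        for (int i = 0; i < QK4_0; i++) {
            if (fabsf(xp[i]) > amax) { amax = fabsf(xp[i]); vmax = xp[i]; }
        }

        const float d  = vmax / -8.0f;
        const float id = (d != 0.0f) ? 1.0f / d : 0.0f;
        out[b].d = CK_FP32_TO_FP16(d);   /* hypothetical fp32 -> fp16 helper */

        /* Quantize to [0, 15] with a +8 bias, two weights per byte. */
        for (int i = 0; i < QK4_0 / 2; i++) {
            int q0 = (int)(xp[2*i + 0] * id + 8.5f);
            int q1 = (int)(xp[2*i + 1] * id + 8.5f);
            q0 = q0 < 0 ? 0 : (q0 > 15 ? 15 : q0);
            q1 = q1 < 0 ? 0 : (q1 > 15 ? 15 : q1);
            out[b].qs[i] = (uint8_t)(q0 | (q1 << 4));
        }
    }
}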
void gemv_q4_0_backward_ref(float *dX, const void *W, const float *dY, int M, int K) {
    const block_q4_0 *blocks = (const block_q4_0 *)W;
    const int blocks_per_row = K / QK4_0;

    memset(dX, 0, K * sizeof(float));

    for (int row = 0; row < M; row++) {
        const float dy = dY[row];

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_0 *block = &blocks[row * blocks_per_row + b];
            const float d = CK_FP16_TO_FP32(block->d);   /* per-block scale */
            float *dxp = &dX[b * QK4_0];

            for (int i = 0; i < QK4_0 / 2; i++) {
                const uint8_t packed = block->qs[i];
                const int8_t q0 = (packed & 0x0F) - 8;
                const int8_t q1 = (packed >> 4) - 8;

                dxp[2*i + 0] += d * (float)q0 * dy;
                dxp[2*i + 1] += d * (float)q1 * dy;
            }
        }
    }
}
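In matrix terms this computes dX = W^T dY, i.e. dX[k] = sum over rows of W[row][k] * dY[row], with each W[row][k] decoded on the fly from its 4-bit block. The upfront memset is what makes the += accumulation across rows correct.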
void gemv_q4_0_backward_avx512(float *dX, const void *W, const float *dY, int M, int K) {
    const block_q4_0 *blocks = (const block_q4_0 *)W;
    const int blocks_per_row = K / QK4_0;
    const __m512i offset  = _mm512_set1_epi32(8);
    const __m512i mask_lo = _mm512_set1_epi32(0x0F);

    memset(dX, 0, K * sizeof(float));

    for (int row = 0; row < M; row++) {
        const __m512 vdy = _mm512_set1_ps(dY[row]);

        for (int b = 0; b < blocks_per_row; b++) {
            const block_q4_0 *block = &blocks[row * blocks_per_row + b];
            const __m512 vscale = _mm512_set1_ps(CK_FP16_TO_FP32(block->d));
            float *dxp = &dX[b * QK4_0];

            /* Decode the 32 weights exactly as in the forward pass. */
            __m128i packed = _mm_loadu_si128((const __m128i *)block->qs);
            __m512i bytes  = _mm512_cvtepu8_epi32(packed);

            __m512i lo = _mm512_sub_epi32(_mm512_and_epi32(bytes, mask_lo), offset);
            __m512i hi = _mm512_sub_epi32(_mm512_srli_epi32(bytes, 4), offset);

            __m512 w_lo = _mm512_mul_ps(_mm512_cvtepi32_ps(lo), vscale);
            __m512 w_hi = _mm512_mul_ps(_mm512_cvtepi32_ps(hi), vscale);

            __m512 grad_lo = _mm512_mul_ps(w_lo, vdy);
            __m512 grad_hi = _mm512_mul_ps(w_hi, vdy);

            /* Spill to scalars for the strided accumulation into dX. */
            float grad_lo_arr[16], grad_hi_arr[16];
            _mm512_storeu_ps(grad_lo_arr, grad_lo);
            _mm512_storeu_ps(grad_hi_arr, grad_hi);

            for (int i = 0; i < 16; i++) {
                dxp[2*i + 0] += grad_lo_arr[i];
                dxp[2*i + 1] += grad_hi_arr[i];
            }
        }
    }
}
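The scalar spill loop above is correct but serializes the accumulation. An alternative keeps the data in registers by interleaving grad_lo and grad_hi with a two-source permute, then doing two vector read-modify-writes (a sketch, not from the original source):

/* Interleave grad_lo/grad_hi back into memory order and accumulate
   with two vector loads/stores instead of 32 scalar updates. */
const __m512i idx0 = _mm512_setr_epi32(0, 16, 1, 17, 2, 18, 3, 19,
                                       4, 20, 5, 21, 6, 22, 7, 23);
const __m512i idx1 = _mm512_setr_epi32(8, 24, 9, 25, 10, 26, 11, 27,
                                       12, 28, 13, 29, 14, 30, 15, 31);
__m512 first  = _mm512_permutex2var_ps(grad_lo, idx0, grad_hi);
__m512 second = _mm512_permutex2var_ps(grad_lo, idx1, grad_hi);
_mm512_storeu_ps(dxp,      _mm512_add_ps(_mm512_loadu_ps(dxp),      first));
_mm512_storeu_ps(dxp + 16, _mm512_add_ps(_mm512_loadu_ps(dxp + 16), second));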
void gemv_q4_0_backward(float *dX, const void *W, const float *dY, int M, int K) {
    /* Reconstructed dispatch, mirroring the forward path. */
#if defined(__AVX512F__)
    gemv_q4_0_backward_avx512(dX, W, dY, M, K);
#else
    gemv_q4_0_backward_ref(dX, W, dY, M, K);
#endif
}
void gemm_q4_0_backward(float *dX, const void *W, const float *dY, int M, int N, int K) {
    /* Batched backward: one per-column backward pass, mirroring gemm_q4_0.
       The loop body is not shown in the excerpt; this call is inferred. */
    for (int n = 0; n < N; n++) {
        gemv_q4_0_backward(&dX[n * K], W, &dY[n * M], M, K);
    }
}
Functions and types defined in this file:

block_q4_0
    Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
float dot_q4_0(const void *w_q4_0, const float *x, int K)
void gemv_q4_0_ref(float *y, const void *W, const float *x, int M, int K)
    Matrix-vector multiply with Q4_0 weights (scalar reference).
void gemv_q4_0(float *y, const void *W, const float *x, int M, int K)
    Auto-dispatch GEMV.
void gemm_q4_0(float *Y, const void *W, const float *X, int M, int N, int K)
    Matrix-matrix multiply with Q4_0 weights.
void gemm_nt_q4_0(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
    Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
void gemv_q4_0_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
    Backward pass: compute input gradient.
void gemv_q4_0_backward(float *dX, const void *W, const float *dY, int M, int K)
    Auto-dispatch backward.
void gemm_q4_0_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
    Batched backward pass.
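A quick way to sanity-check the kernels, assuming the hypothetical quantize_row_q4_0_sketch above: quantize a small random weight matrix, run the quantized GEMV, and compare against a plain float product (expect small differences, since 4-bit quantization is lossy):

/* Smoke-test sketch; needs <stdio.h> and <stdlib.h>. */
int main(void) {
    enum { M = 4, K = 64 };
    float Wf[M * K], x[K], y_q[M];
    block_q4_0 Wq[M * (K / QK4_0)];

    for (int i = 0; i < M * K; i++) Wf[i] = (float)rand() / RAND_MAX - 0.5f;
    for (int k = 0; k < K; k++)     x[k]  = (float)rand() / RAND_MAX - 0.5f;

    for (int m = 0; m < M; m++)
        quantize_row_q4_0_sketch(&Wf[m * K], &Wq[m * (K / QK4_0)], K);

    gemv_q4_0(y_q, Wq, x, M, K);

    for (int m = 0; m < M; m++) {
        float ref = 0.0f;
        for (int k = 0; k < K; k++) ref += Wf[m * K + k] * x[k];
        printf("row %d: q4_0 = % .4f   fp32 = % .4f\n", m, y_q[m], ref);
    }
    return 0;
}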