39 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
40 #include <immintrin.h>
/* NOTE(review): extraction fragment. The function header (original line 74)
 * and the branch conditions (original lines 76/79/82) are missing from this
 * view, so only the three packing cases are visible. Presumably this unpacks
 * the 6-bit sub-block scale and min from the packed 12-byte `scales[]` array
 * of a K-quant block (low 6 bits stored directly for early sub-blocks, later
 * sub-blocks reassembled from a low nibble plus the top-2 bits of an earlier
 * byte) -- verify against the block_q5_K layout in the full source. */
75 uint8_t *scale, uint8_t *min) {
/* Case 1 (presumably j < 4): scale/min are the low 6 bits, stored directly. */
77 *scale = scales[j] & 63;
78 *min = scales[j + 4] & 63;
/* Case 2: low nibble of scales[j+4] combined with the top-2 bits of scales[j-4]. */
80 *scale = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4);
/* NOTE(review): upstream llama.cpp get_scale_min_k4 derives the min's high
 * bits from scales[j] (not scales[j-4]); whether this line is a transcription
 * variant or a real bug cannot be judged without the missing branch
 * conditions -- TODO confirm against the reference implementation. */
81 *min = (scales[j + 4] >> 4) | ((scales[j - 4] >> 6) << 4);
/* Case 3: same nibble+high-bit reassembly, shifted back one more group. */
83 *scale = (scales[j - 4] & 0x0F) | ((scales[j - 8] >> 6) << 4);
84 *min = (scales[j - 4] >> 4) | ((scales[j - 8] >> 6) << 4);
/* Reference (scalar) GEMV over Q5_K-quantized weights: y = W * x, where W is
 * M rows of K quantized values, stored as blocks of QK_K elements.
 * NOTE(review): extraction fragment -- the function's opening brace, the
 * `blocks` pointer setup, the per-block `d`/`dmin` loads, the per-sub-block
 * `sc`/`m` fetch (presumably via get_q5_k_scale_min), the `sum` accumulator
 * initialization, and the final `y[m] = sum` store are all on lines missing
 * from this view. Comments below are limited to what the visible lines show. */
92 void gemv_q5_k_ref(
float *y,
const void *W,
const float *x,
int M,
int K)
/* K must be a multiple of QK_K for this division to be exact -- TODO confirm
 * the caller contract in the full source. */
95 const int blocks_per_row = K /
QK_K;
/* One output element per weight row. */
97 for (
int m = 0; m < M; m++) {
98 const float *x_row = x;
101 for (
int b = 0; b < blocks_per_row; b++) {
102 const block_q5_K *block = &blocks[m * blocks_per_row + b];
/* Raw block fields: packed 6-bit scales/mins, high bits, and low nibbles. */
105 const uint8_t *scales = block->
scales;
106 const uint8_t *qh = block->
qh;
107 const uint8_t *qs = block->
qs;
/* 8 sub-blocks of 32 values each (8 * 32 = 256, presumably == QK_K). */
110 for (
int sb = 0; sb < 8; sb++) {
/* Effective scale/min: 6-bit sc and m (fetched on a missing line) times the
 * block-level fp scales d/dmin, normalized by 64 (the 6-bit range).
 * NOTE(review): `m` here must be the sub-block minimum, which would shadow
 * the row index `m` from the outer loop -- verify the inner declaration on
 * the missing lines (original ~112-113). */
114 const float d_sub = d * (float)sc / 64.0f;
115 const float m_sub = dmin * (float)m / 64.0f;
/* Each sub-block consumes 16 bytes of low nibbles and 4 bytes of high bits. */
118 const int qs_offset = sb * 16;
119 const int qh_offset = sb * 4;
121 for (
int i = 0; i < 32; i++) {
/* Rebuild the 5-bit quant: low nibble from qs (even i -> low nibble, odd i
 * -> high nibble) plus one high bit from qh as bit 4. */
122 uint8_t qs_val = (qs[qs_offset + i/2] >> (4 * (i % 2))) & 0xF;
123 uint8_t qh_bit = (qh[qh_offset + i/8] >> (i % 8)) & 1;
124 uint8_t q = qs_val | (qh_bit << 4);
/* Dequantize (asymmetric: scale * q - min) and accumulate the dot product. */
127 float w = d_sub * (float)q - m_sub;
128 sum += w * x_row[b *
QK_K + sb * 32 + i];
/* NOTE(review): extraction fragment -- interior of gemm_nt_q5_k_ref (the
 * signature at original ~line 150 and the same supporting lines as in the
 * GEMV reference are missing). Computes C = A * B^T + bias with B stored as
 * Q5_K blocks per output column (N rows of K quantized values); the
 * dequantization is identical to gemv_q5_k_ref above. */
152 const int blocks_per_col = K /
QK_K;
154 for (
int m = 0; m < M; m++) {
155 const float *a_row = &A[m * K];
157 for (
int n = 0; n < N; n++) {
160 for (
int b = 0; b < blocks_per_col; b++) {
/* B is indexed by output column n (NT layout: B row-major in its own rows). */
161 const block_q5_K *block = &blocks[n * blocks_per_col + b];
164 const uint8_t *scales = block->
scales;
165 const uint8_t *qh = block->
qh;
166 const uint8_t *qs = block->
qs;
169 for (
int sb = 0; sb < 8; sb++) {
/* NOTE(review): `(float)m` here collides with the row index `m` of the
 * outermost loop unless an inner `m` (the sub-block minimum from
 * get_q5_k_scale_min) is declared on the missing lines (~170-172). If no
 * such shadowing declaration exists, m_sub is computed from the wrong
 * quantity -- confirm against the full source. */
173 const float d_sub = d * (float)sc / 64.0f;
174 const float m_sub = dmin * (float)m / 64.0f;
176 const int qs_offset = sb * 16;
177 const int qh_offset = sb * 4;
179 for (
int i = 0; i < 32; i++) {
/* Same 5-bit reconstruction as the GEMV path: low nibble + qh bit 4. */
180 uint8_t qs_val = (qs[qs_offset + i/2] >> (4 * (i % 2))) & 0xF;
181 uint8_t qh_bit = (qh[qh_offset + i/8] >> (i % 8)) & 1;
182 uint8_t q = qs_val | (qh_bit << 4);
184 float w = d_sub * (float)q - m_sub;
185 sum += w * a_row[b *
QK_K + sb * 32 + i];
/* Optional bias: NULL bias is treated as all-zeros. */
190 C[m * N + n] = sum + (bias ? bias[n] : 0.0f);
/* Public GEMV entry point: dispatches to the widest SIMD implementation
 * available at compile time.
 * NOTE(review): extraction fragment -- every branch body (original lines
 * 202-203, 205-206, 208-209, 211+) is missing, including the presumed
 * scalar fallback to gemv_q5_k_ref; only the #if ladder survives. */
199 void gemv_q5_k(
float *y,
const void *W,
const float *x,
int M,
int K)
201 #if defined(__AVX512F__)
204 #elif defined(__AVX2__)
207 #elif defined(__AVX__)
210 #elif defined(__SSE4_1__)
/* NOTE(review): extraction fragment -- SIMD dispatch ladder presumably
 * belonging to gemm_nt_q5_k (its signature, all branch bodies, and the
 * scalar fallback are on missing lines). Mirrors the gemv_q5_k dispatch:
 * widest available of AVX-512F / AVX2 / AVX / SSE4.1. */
224 #if defined(__AVX512F__)
227 #elif defined(__AVX2__)
230 #elif defined(__AVX__)
233 #elif defined(__SSE4_1__)
Quantization block structures for weight-only quantization.
#define CK_FP16_TO_FP32(x)
void gemv_q5_k_ref(float *y, const void *W, const float *x, int M, int K)
static void get_q5_k_scale_min(int j, const uint8_t *scales, uint8_t *scale, uint8_t *min)
void gemv_q5_k(float *y, const void *W, const float *x, int M, int K)
void gemm_nt_q5_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q5_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)