31 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
32 #include <immintrin.h>
43 const int blocks_per_row = K /
QK_K;
46 for (
int b = 0; b < blocks_per_row; ++b) {
50 const uint8_t *ql = block->
ql;
51 const uint8_t *qh = block->
qh;
52 const int8_t *sc = block->
scales;
53 const float *xp = x + (size_t)b *
QK_K;
55 for (
int n = 0; n <
QK_K; n += 128) {
56 for (
int l = 0; l < 32; ++l) {
57 const int is = l / 16;
58 const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
59 const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
60 const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
61 const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
63 sum += (d * (float)sc[is + 0] * (
float)q1) * xp[l + 0];
64 sum += (d * (float)sc[is + 2] * (
float)q2) * xp[l + 32];
65 sum += (d * (float)sc[is + 4] * (
float)q3) * xp[l + 64];
66 sum += (d * (float)sc[is + 6] * (
float)q4) * xp[l + 96];
86 #if defined(__AVX__) && !defined(__AVX512F__)
91 const int blocks_per_row = K /
QK_K;
94 __m256 acc0 = _mm256_setzero_ps();
95 __m256 acc1 = _mm256_setzero_ps();
96 __m256 acc2 = _mm256_setzero_ps();
97 __m256 acc3 = _mm256_setzero_ps();
99 for (
int b = 0; b < blocks_per_row; ++b) {
103 const uint8_t *ql = block->
ql;
104 const uint8_t *qh = block->
qh;
105 const int8_t *sc = block->
scales;
106 const float *xp = x + (size_t)b *
QK_K;
109 for (
int n = 0; n <
QK_K; n += 128) {
111 for (
int l = 0; l < 32; l += 8) {
112 const int is = l / 16;
115 float dq0[8], dq1[8], dq2[8], dq3[8];
117 for (
int i = 0; i < 8; i++) {
119 const int8_t q1 = (int8_t)((ql[idx + 0] & 0xF) | (((qh[idx] >> 0) & 3) << 4)) - 32;
120 const int8_t q2 = (int8_t)((ql[idx + 32] & 0xF) | (((qh[idx] >> 2) & 3) << 4)) - 32;
121 const int8_t q3 = (int8_t)((ql[idx + 0] >> 4) | (((qh[idx] >> 4) & 3) << 4)) - 32;
122 const int8_t q4 = (int8_t)((ql[idx + 32] >> 4) | (((qh[idx] >> 6) & 3) << 4)) - 32;
124 const int is_i = (l + i) / 16;
125 dq0[i] = d * (float)sc[is_i + 0] * (
float)q1;
126 dq1[i] = d * (float)sc[is_i + 2] * (
float)q2;
127 dq2[i] = d * (float)sc[is_i + 4] * (
float)q3;
128 dq3[i] = d * (float)sc[is_i + 6] * (
float)q4;
131 __m256 vw0 = _mm256_loadu_ps(dq0);
132 __m256 vw1 = _mm256_loadu_ps(dq1);
133 __m256 vw2 = _mm256_loadu_ps(dq2);
134 __m256 vw3 = _mm256_loadu_ps(dq3);
136 __m256 vx0 = _mm256_loadu_ps(&xp[l + 0]);
137 __m256 vx1 = _mm256_loadu_ps(&xp[l + 32]);
138 __m256 vx2 = _mm256_loadu_ps(&xp[l + 64]);
139 __m256 vx3 = _mm256_loadu_ps(&xp[l + 96]);
141 acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(vw0, vx0));
142 acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(vw1, vx1));
143 acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(vw2, vx2));
144 acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(vw3, vx3));
154 __m256 sum01 = _mm256_add_ps(acc0, acc1);
155 __m256 sum23 = _mm256_add_ps(acc2, acc3);
156 __m256 sum = _mm256_add_ps(sum01, sum23);
159 __m128 hi = _mm256_extractf128_ps(sum, 1);
160 __m128 lo = _mm256_castps256_ps128(sum);
161 __m128 sum128 = _mm_add_ps(hi, lo);
162 sum128 = _mm_hadd_ps(sum128, sum128);
163 sum128 = _mm_hadd_ps(sum128, sum128);
165 return _mm_cvtss_f32(sum128);
174 if (!y || !W || !x) {
177 if (M <= 0 || K <= 0) {
183 const int blocks_per_row = K /
QK_K;
185 for (
int row = 0; row < M; ++row) {
186 const block_q6_K *w_row = blocks + (size_t)row * (
size_t)blocks_per_row;
187 #if defined(__AVX__) && !defined(__AVX512F__)
188 y[row] = dot_q6_k_avx(w_row, x, K);
200 if (!Y || !W || !X) {
203 if (M <= 0 || N <= 0 || K <= 0) {
207 for (
int n = 0; n < N; ++n) {
208 gemv_q6_k(&Y[n * M], W, &X[n * K], M, K);
218 if (!A || !B || !
C) {
221 if (M <= 0 || N <= 0 || K <= 0) {
234 for (
int i = 0; i < M; ++i) {
235 float *row =
C + (size_t)i * (
size_t)N;
236 for (
int j = 0; j < N; ++j) {
Quantization block structures for weight-only quantization.
#define GGML_FP16_TO_FP32
void gemm_q6_k(float *Y, const void *W, const float *X, int M, int N, int K)
void gemv_q6_k(float *y, const void *W, const float *x, int M, int K)
void gemm_nt_q6_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q6_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
static float dot_q6_k_ref(const block_q6_K *w, const float *x, int K)