#if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
#include <immintrin.h>
#endif

#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
#include <immintrin.h>
#endif
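The block structures and the CK_FP16_TO_FP32 helper referenced by the kernels below are defined elsewhere in the file; what the kernels rely on is a per-block fp16 scale d, 32 packed quants qs, and, for q5_0, the packed fifth bits qh. A minimal sketch of the assumed layouts (field types here are illustrative, following the common 32-element GGML-style blocks and assuming <stdint.h>):

#define QK8_0 32    /* elements per q8_0 block */
#define QK5_0 32    /* elements per q5_0 block */

typedef struct {
    uint16_t d;             /* per-block scale, IEEE fp16 bit pattern (read via CK_FP16_TO_FP32) */
    int8_t   qs[QK8_0];     /* 32 signed 8-bit quants */
} block_q8_0;

typedef struct {
    uint16_t d;             /* per-block scale, IEEE fp16 bit pattern */
    uint8_t  qh[4];         /* fifth bit of each of the 32 weights, packed into 32 bits */
    uint8_t  qs[QK5_0 / 2]; /* low 4 bits of each weight, two per byte */
} block_q5_0;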
/* Scalar reference: C = A * B^T for row-major arrays of q8_0 blocks. */
void gemm_nt_q8_0_q8_0_ref(const void *A, const void *B, float *C, int M, int N, int K) {
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q8_0 *b_blocks = (const block_q8_0 *)B;
    const int nb = K / QK8_0;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;

        for (int n = 0; n < N; n++) {
            const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
                const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
                const float d = d_a * d_b;

                /* Integer dot product of one 32-element block, scaled afterwards. */
                int32_t sumi = 0;
                for (int j = 0; j < QK8_0; j++) {
                    sumi += (int32_t)a_row[ib].qs[j] * (int32_t)b_row[ib].qs[j];
                }
                sum += d * (float)sumi;
            }
            C[(size_t)m * N + n] = sum;
        }
    }
}
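For context, this is roughly how a row of floats becomes the q8_0 input these kernels consume: each 32-element block stores the values scaled so that the largest magnitude maps to ±127, plus an fp16 scale. A hedged sketch (the helper name is illustrative, not part of this file; it uses F16C's _cvtss_sh for the float-to-fp16 store and <math.h>):

static void quantize_row_q8_0_sketch(const float *x, block_q8_0 *dst, int k) {
    const int nb = k / QK8_0;
    for (int ib = 0; ib < nb; ib++) {
        /* Per-block scale: the largest magnitude maps to +/-127. */
        float amax = 0.0f;
        for (int j = 0; j < QK8_0; j++) {
            const float v = fabsf(x[ib * QK8_0 + j]);
            if (v > amax) amax = v;
        }
        const float d  = amax / 127.0f;
        const float id = (d != 0.0f) ? 1.0f / d : 0.0f;

        dst[ib].d = _cvtss_sh(d, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
        for (int j = 0; j < QK8_0; j++) {
            dst[ib].qs[j] = (int8_t)roundf(x[ib * QK8_0 + j] * id);
        }
    }
}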
#if defined(__AVX2__)
/* AVX2 kernel: one 32-element q8_0 block per iteration, widened to int16 lanes. */
void gemm_nt_q8_0_q8_0_avx2(const void *A, const void *B, float *C, int M, int N, int K) {
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q8_0 *b_blocks = (const block_q8_0 *)B;
    const int nb = K / QK8_0;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;

        for (int n = 0; n < N; n++) {
            const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
                const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
                const float d = d_a * d_b;

                __m256i va = _mm256_loadu_si256((const __m256i *)a_row[ib].qs);
                __m256i vb = _mm256_loadu_si256((const __m256i *)b_row[ib].qs);

                /* Split into 128-bit halves and sign-extend int8 -> int16. */
                __m128i va_lo = _mm256_castsi256_si128(va);
                __m128i va_hi = _mm256_extracti128_si256(va, 1);
                __m128i vb_lo = _mm256_castsi256_si128(vb);
                __m128i vb_hi = _mm256_extracti128_si256(vb, 1);

                __m256i va_lo_16 = _mm256_cvtepi8_epi16(va_lo);
                __m256i vb_lo_16 = _mm256_cvtepi8_epi16(vb_lo);
                __m256i va_hi_16 = _mm256_cvtepi8_epi16(va_hi);
                __m256i vb_hi_16 = _mm256_cvtepi8_epi16(vb_hi);

                /* 16-bit products, then pairwise-add into int32 lanes via madd-by-ones. */
                __m256i prod_lo = _mm256_mullo_epi16(va_lo_16, vb_lo_16);
                __m256i prod_hi = _mm256_mullo_epi16(va_hi_16, vb_hi_16);

                __m256i sum_lo = _mm256_madd_epi16(prod_lo, _mm256_set1_epi16(1));
                __m256i sum_hi = _mm256_madd_epi16(prod_hi, _mm256_set1_epi16(1));
                __m256i sum_32 = _mm256_add_epi32(sum_lo, sum_hi);

                /* Horizontal reduction of the eight int32 lanes to a scalar. */
                __m128i sum_128 = _mm_add_epi32(
                    _mm256_castsi256_si128(sum_32),
                    _mm256_extracti128_si256(sum_32, 1));
                sum_128 = _mm_add_epi32(sum_128, _mm_srli_si128(sum_128, 8));
                sum_128 = _mm_add_epi32(sum_128, _mm_srli_si128(sum_128, 4));
                int32_t sumi = _mm_cvtsi128_si32(sum_128);

                sum += d * (float)sumi;
            }
            C[(size_t)m * N + n] = sum;
        }
    }
}
#endif /* __AVX2__ */
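A side note on the reduction above: each 16-bit product of two int8 values has magnitude at most 16384, so sums of adjacent pairs comfortably fit in int32. The separate _mm256_mullo_epi16 and madd-by-ones steps could therefore be fused into a single _mm256_madd_epi16 per half. A sketch of the equivalent inner step, offered as an alternative rather than the file's code:

                /* Multiply int16 lanes and add adjacent pairs into int32 in one instruction. */
                __m256i sum_lo = _mm256_madd_epi16(va_lo_16, vb_lo_16);
                __m256i sum_hi = _mm256_madd_epi16(va_hi_16, vb_hi_16);
                __m256i sum_32 = _mm256_add_epi32(sum_lo, sum_hi);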
#if defined(__AVX__) && !defined(__AVX2__)
/* AVX-only kernel: 128-bit integer path, 8 int8 values per madd after widening. */
void gemm_nt_q8_0_q8_0_avx(const void *A, const void *B, float *C, int M, int N, int K) {
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q8_0 *b_blocks = (const block_q8_0 *)B;
    const int nb = K / QK8_0;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;
        for (int n = 0; n < N; n++) {
            const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                const float d = CK_FP16_TO_FP32(a_row[ib].d) * CK_FP16_TO_FP32(b_row[ib].d);

                const int8_t *a_qs = a_row[ib].qs;
                const int8_t *b_qs = b_row[ib].qs;

                /* Four 8-element chunks: sign-extend int8 -> int16, multiply-add into int32. */
                __m128i d0 = _mm_madd_epi16(
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(a_qs + 0))),
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(b_qs + 0))));
                __m128i d1 = _mm_madd_epi16(
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(a_qs + 8))),
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(b_qs + 8))));
                __m128i d2 = _mm_madd_epi16(
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(a_qs + 16))),
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(b_qs + 16))));
                __m128i d3 = _mm_madd_epi16(
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(a_qs + 24))),
                    _mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i *)(b_qs + 24))));

                /* Sum the four partial vectors, then reduce to a scalar. */
                __m128i s4 = _mm_add_epi32(_mm_add_epi32(d0, d1),
                                           _mm_add_epi32(d2, d3));
                s4 = _mm_add_epi32(s4, _mm_srli_si128(s4, 8));
                s4 = _mm_add_epi32(s4, _mm_srli_si128(s4, 4));
                sum += d * (float)_mm_cvtsi128_si32(s4);
            }
            C[(size_t)m * N + n] = sum;
        }
    }
}
#endif /* __AVX__ && !__AVX2__ */
#if defined(__AVX512F__)
/* AVX-512 kernel: the whole 32-byte block is widened to int16 in one 512-bit register.
 * Note: the 512-bit int16 intrinsics below also require AVX-512BW. */
void gemm_nt_q8_0_q8_0_avx512(const void *A, const void *B, float *C, int M, int N, int K) {
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q8_0 *b_blocks = (const block_q8_0 *)B;
    const int nb = K / QK8_0;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;

        for (int n = 0; n < N; n++) {
            const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
                const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
                const float d = d_a * d_b;

                __m256i va_256 = _mm256_loadu_si256((const __m256i *)a_row[ib].qs);
                __m256i vb_256 = _mm256_loadu_si256((const __m256i *)b_row[ib].qs);

                /* Sign-extend 32 x int8 -> 32 x int16. */
                __m512i va_16 = _mm512_cvtepi8_epi16(va_256);
                __m512i vb_16 = _mm512_cvtepi8_epi16(vb_256);

                /* 16-bit products, pairwise-added into int32 lanes, then reduced. */
                __m512i prod = _mm512_mullo_epi16(va_16, vb_16);
                __m512i sum_32 = _mm512_madd_epi16(prod, _mm512_set1_epi16(1));
                int32_t sumi = _mm512_reduce_add_epi32(sum_32);

                sum += d * (float)sumi;
            }
            C[(size_t)m * N + n] = sum;
        }
    }
}
#endif /* __AVX512F__ */
#if defined(__AVX512VNNI__)
/* AVX-512 VNNI build: the block dot product follows the same widen / multiply /
 * madd-by-ones sequence as the AVX-512 kernel above. */
void gemm_nt_q8_0_q8_0_vnni(const void *A, const void *B, float *C, int M, int N, int K) {
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q8_0 *b_blocks = (const block_q8_0 *)B;
    const int nb = K / QK8_0;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;

        for (int n = 0; n < N; n++) {
            const block_q8_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
                const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
                const float d = d_a * d_b;

                __m256i va = _mm256_loadu_si256((const __m256i *)a_row[ib].qs);
                __m256i vb = _mm256_loadu_si256((const __m256i *)b_row[ib].qs);

                __m512i va_16 = _mm512_cvtepi8_epi16(va);
                __m512i vb_16 = _mm512_cvtepi8_epi16(vb);
                __m512i prod = _mm512_mullo_epi16(va_16, vb_16);
                __m512i sum_32 = _mm512_madd_epi16(prod, _mm512_set1_epi16(1));
                int32_t sumi = _mm512_reduce_add_epi32(sum_32);

                sum += d * (float)sumi;
            }
            C[(size_t)m * N + n] = sum;
        }
    }
}
#endif /* __AVX512VNNI__ */
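Despite the name, the loop body above does not issue a VNNI instruction; vpdpbusd expects unsigned-by-signed 8-bit operands, so it cannot be applied directly to two signed q8_0 vectors without a correction term. A sketch of how the per-block reduction could instead use VNNI on the widened int16 values (vpdpwssd; assumes AVX512VNNI together with AVX512BW), shown as an alternative rather than the file's code:

                /* vpdpwssd: multiply int16 pairs, add adjacent pairs, accumulate into int32 lanes. */
                __m512i acc = _mm512_dpwssd_epi32(_mm512_setzero_si512(),
                                                  _mm512_cvtepi8_epi16(va),
                                                  _mm512_cvtepi8_epi16(vb));
                int32_t sumi = _mm512_reduce_add_epi32(acc);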
/* Scalar reference for the mixed q5_0 x q8_0 product: the q5_0 side stores each
 * weight as a low nibble in qs plus a fifth bit packed into qh. */
void gemm_nt_q5_0_q8_0_ref(const void *A, const void *B, float *C, int M, int N, int K) {
    const block_q8_0 *a_blocks = (const block_q8_0 *)A;
    const block_q5_0 *b_blocks = (const block_q5_0 *)B;
    const int nb = K / QK5_0;

    for (int m = 0; m < M; m++) {
        const block_q8_0 *a_row = a_blocks + (size_t)m * nb;

        for (int n = 0; n < N; n++) {
            const block_q5_0 *b_row = b_blocks + (size_t)n * nb;
            float sum = 0.0f;

            for (int ib = 0; ib < nb; ib++) {
                const float d_a = CK_FP16_TO_FP32(a_row[ib].d);
                const float d_b = CK_FP16_TO_FP32(b_row[ib].d);
                const float d = d_a * d_b;

                uint32_t qh;
                memcpy(&qh, b_row[ib].qh, sizeof(qh));

                int32_t sumi = 0;
                for (int j = 0; j < 16; j++) {
                    /* Weight j: low nibble of qs[j] plus bit j of qh, minus the 16 offset. */
                    const uint8_t xh_0 = ((qh >> j) & 1) << 4;
                    const int8_t w0 = (int8_t)(((b_row[ib].qs[j] & 0x0F) | xh_0) - 16);

                    /* Weight j + 16: high nibble of qs[j] plus bit j + 16 of qh. */
                    const uint8_t xh_1 = ((qh >> (j + 16)) & 1) << 4;
                    const int8_t w1 = (int8_t)(((b_row[ib].qs[j] >> 4) | xh_1) - 16);

                    sumi += (int32_t)w0 * (int32_t)a_row[ib].qs[j];
                    sumi += (int32_t)w1 * (int32_t)a_row[ib].qs[j + 16];
                }
                sum += d * (float)sumi;
            }
            C[(size_t)m * N + n] = sum;
        }
    }
}
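To make the bit layout concrete: byte j of qs carries the low nibbles of weights j and j + 16, qh carries their fifth bits, and subtracting 16 recenters the range to [-16, 15]; for example, a low nibble of 0xF with its fifth bit set decodes to (0x0F | 0x10) - 16 = 15. A small dequantization sketch under the same assumed block_q5_0 layout (the helper name is illustrative):

/* Expand one q5_0 block into 32 floats. */
static void dequantize_block_q5_0_sketch(const block_q5_0 *b, float *out) {
    const float d = CK_FP16_TO_FP32(b->d);

    uint32_t qh;
    memcpy(&qh, b->qh, sizeof(qh));

    for (int j = 0; j < 16; j++) {
        const uint8_t xh_0 = ((qh >> j) & 1) << 4;        /* fifth bit of weight j      */
        const uint8_t xh_1 = ((qh >> (j + 16)) & 1) << 4; /* fifth bit of weight j + 16 */

        out[j]      = d * (float)((int8_t)(((b->qs[j] & 0x0F) | xh_0) - 16));
        out[j + 16] = d * (float)((int8_t)(((b->qs[j] >> 4)   | xh_1) - 16));
    }
}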
typedef struct {
    uint8_t  palette_id;
    uint8_t  start_row;
    uint8_t  reserved[14];
    uint16_t colsb[16];
    uint8_t  rows[16];
} tile_config_t;

/* Configure the AMX tile registers once per thread: palette 1, eight 16 x 64-byte tiles. */
static void amx_tile_config_init(void) {
    static __thread int initialized = 0;
    if (initialized) return;

    tile_config_t tc = {0};
    tc.palette_id = 1;

    tc.rows[0] = 16; tc.colsb[0] = 64;
    tc.rows[1] = 16; tc.colsb[1] = 64;
    tc.rows[2] = 16; tc.colsb[2] = 64;
    tc.rows[3] = 16; tc.colsb[3] = 64;
    tc.rows[4] = 16; tc.colsb[4] = 64;
    tc.rows[5] = 16; tc.colsb[5] = 64;
    tc.rows[6] = 16; tc.colsb[6] = 64;
    tc.rows[7] = 16; tc.colsb[7] = 64;

    _tile_loadconfig(&tc);
    initialized = 1;
}
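One operational detail worth noting: on Linux (5.16+), _tile_loadconfig will fault unless the process has first asked the kernel for AMX tile-data permission. A minimal sketch of that request, using the documented arch_prctl constants; the helper name is illustrative and not part of this file:

#include <sys/syscall.h>
#include <unistd.h>

#define ARCH_REQ_XCOMP_PERM 0x1023  /* arch_prctl: request extended component permission */
#define XFEATURE_XTILEDATA  18      /* AMX tile data state component */

/* Returns 1 if the kernel granted AMX tile-data permission to this process. */
static int amx_request_permission(void) {
    return syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) == 0;
}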
void gemm_nt_q8_0_q8_0_amx(const void *A, const void *B, float *C, int M, int N, int K) {
    amx_tile_config_init();
    /* Tile kernel not shown here; falls back to the AVX-512 path. */
    gemm_nt_q8_0_q8_0_avx512(A, B, C, M, N, K);
}

void gemm_nt_q5_0_q8_0_amx(const void *A, const void *B, float *C, int M, int N, int K) {
    amx_tile_config_init();
    /* ... remainder of the q5_0 path elided ... */
}
#elif defined(__AVX512VNNI__)
    return "AVX-512 VNNI";
#elif defined(__AVX512F__)
    /* ... */
#elif defined(__AVX2__)
    /* ... */
#elif defined(__AVX__)
    /* ... */
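For reference, a complete sketch of the reporting function, with the leading branch and the return strings that are not shown above filled in as assumptions:

const char *gemm_batch_int8_impl_name(void) {
#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
    return "AMX-INT8";      /* assumed */
#elif defined(__AVX512VNNI__)
    return "AVX-512 VNNI";
#elif defined(__AVX512F__)
    return "AVX-512";       /* assumed */
#elif defined(__AVX2__)
    return "AVX2";          /* assumed */
#elif defined(__AVX__)
    return "AVX";           /* assumed */
#else
    return "scalar";        /* assumed */
#endif
}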
void gemm_nt_q8_0_q8_0(const void *A, const void *B, const float *bias, float *C, int M, int N, int K) {
#if defined(__AVX512VNNI__)
    gemm_nt_q8_0_q8_0_vnni(A, B, C, M, N, K);
#elif defined(__AVX512F__)
    gemm_nt_q8_0_q8_0_avx512(A, B, C, M, N, K);
#elif defined(__AVX2__)
    gemm_nt_q8_0_q8_0_avx2(A, B, C, M, N, K);
#elif defined(__AVX__)
    gemm_nt_q8_0_q8_0_avx(A, B, C, M, N, K);
#else
    gemm_nt_q8_0_q8_0_ref(A, B, C, M, N, K);
#endif

    /* Optional bias, broadcast across the rows of C. */
    if (bias) {
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < N; n++) {
                C[(size_t)m * N + n] += bias[n];
            }
        }
    }
}
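Putting it together, a hypothetical call site: A packs M rows of K/QK8_0 q8_0 blocks, B packs N rows in the same layout (the "nt" in the name: C = A * B^T), and bias may be NULL. The quantize_row_q8_0_sketch helper from earlier is illustrative, not part of this file:

#include <stdio.h>

int main(void) {
    enum { M = 4, N = 8, K = 64 };          /* K must be a multiple of QK8_0 */

    float a_f[M * K], b_f[N * K];
    for (int i = 0; i < M * K; i++) a_f[i] = (float)(i % 7) - 3.0f;
    for (int i = 0; i < N * K; i++) b_f[i] = (float)(i % 5) - 2.0f;

    /* Quantize each row into its own run of K / QK8_0 blocks. */
    block_q8_0 a_q[M * (K / QK8_0)], b_q[N * (K / QK8_0)];
    for (int m = 0; m < M; m++) quantize_row_q8_0_sketch(a_f + m * K, a_q + m * (K / QK8_0), K);
    for (int n = 0; n < N; n++) quantize_row_q8_0_sketch(b_f + n * K, b_q + n * (K / QK8_0), K);

    float bias[N] = {0};
    float c[M * N];

    printf("int8 GEMM path: %s\n", gemm_batch_int8_impl_name());
    gemm_nt_q8_0_q8_0(a_q, b_q, bias, c, M, N, K);   /* bias may also be NULL */

    printf("C[0][0] = %f\n", c[0]);
    return 0;
}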
Symbols documented in this file:

Quantization block structures for weight-only quantization.

#define CK_FP16_TO_FP32(x)

void gemm_nt_q5_0_q8_0_ref(const void *A, const void *B, float *C, int M, int N, int K)

void gemm_nt_q8_0_q8_0_ref(const void *A, const void *B, float *C, int M, int N, int K)
    Scalar reference: gemm_nt_q8_0_q8_0.

const char *gemm_batch_int8_impl_name(void)
    Get the best implementation name for logging/debugging.

void gemm_nt_q8_0_q8_0(const void *A, const void *B, const float *bias, float *C, int M, int N, int K)
    Dispatcher for gemm_nt_q8_0_q8_0, with optional bias (matches header signature).