23 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
24 #include <immintrin.h>
37 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
/*
 * Horizontal sum of the eight signed 32-bit lanes of a 256-bit vector.
 *
 * The upper 128-bit half is added onto the lower half. Two pairwise
 * horizontal adds then reduce the remaining four lanes, and the total is
 * extracted from lane 0 as a scalar. All additions wrap on overflow
 * (no saturation).
 */
38 static inline int32_t hsum256_epi32(__m256i v) {
39 __m128i lo = _mm256_castsi256_si128(v);
40 __m128i hi = _mm256_extracti128_si256(v, 1);
/* Fold halves: lane i becomes v[i] + v[i + 4] for i = 0..3. */
41 __m128i sum = _mm_add_epi32(lo, hi);
/* After two pairwise hadds every lane holds the full four-lane sum. */
42 sum = _mm_hadd_epi32(sum, sum);
43 sum = _mm_hadd_epi32(sum, sum);
44 return _mm_cvtsi128_si32(sum);
/*
 * Load 32 int8 values from q8 and de-interleave them by index parity.
 *
 * q8:    source of at least 32 bytes, read as two unaligned 16-byte loads
 *        at q8 and q8 + 16.
 * even8: receives q8[0], q8[2], ..., q8[30] (16 bytes).
 * odd8:  receives q8[1], q8[3], ..., q8[31] (16 bytes).
 *
 * NOTE(review): the declarations of even8/odd8 are not visible in this
 * excerpt; presumably they are __m128i * output parameters — confirm.
 */
47 static inline void load_q8_even_odd_16(
const int8_t *q8,
50 const __m128i q8_lo = _mm_loadu_si128((
const __m128i *)q8);
51 const __m128i q8_hi = _mm_loadu_si128((
const __m128i *)(q8 + 16));
/*
 * pshufb control vectors: the first eight indices gather the even (or odd)
 * bytes of a 16-byte chunk into the low 64 bits. Index 0x80 (high bit set)
 * zeroes the destination byte, so the upper eight result bytes are cleared.
 */
52 const __m128i even_mask = _mm_setr_epi8(
53 0, 2, 4, 6, 8, 10, 12, 14,
54 (
char)0x80, (
char)0x80, (
char)0x80, (
char)0x80,
55 (
char)0x80, (
char)0x80, (
char)0x80, (
char)0x80);
56 const __m128i odd_mask = _mm_setr_epi8(
57 1, 3, 5, 7, 9, 11, 13, 15,
58 (
char)0x80, (
char)0x80, (
char)0x80, (
char)0x80,
59 (
char)0x80, (
char)0x80, (
char)0x80, (
char)0x80);
/* Even bytes of each 16-byte half sit in its low 64 bits; concatenate the
 * two low halves so even8 = q8[0], q8[2], ..., q8[30]. */
61 const __m128i q8_lo_even = _mm_shuffle_epi8(q8_lo, even_mask);
62 const __m128i q8_hi_even = _mm_shuffle_epi8(q8_hi, even_mask);
63 *even8 = _mm_unpacklo_epi64(q8_lo_even, q8_hi_even);
/* Same for the odd-indexed bytes: odd8 = q8[1], q8[3], ..., q8[31]. */
65 const __m128i q8_lo_odd = _mm_shuffle_epi8(q8_lo, odd_mask);
66 const __m128i q8_hi_odd = _mm_shuffle_epi8(q8_hi, odd_mask);
67 *odd8 = _mm_unpacklo_epi64(q8_lo_odd, q8_hi_odd);
/*
 * Integer dot product of 32 unsigned 4-bit values with 32 int8 values,
 * using the 256-bit AVX512-VNNI vpdpbusd instruction (compiled only under
 * the __AVX512VNNI__ && __AVX512VL__ guard).
 *
 * q4: 16 bytes holding two 4-bit values per byte (one unaligned 16-byte
 *     load). The low nibble of q4[i] is multiplied with q8[2*i] and the
 *     high nibble with q8[2*i + 1].
 * q8: at least 32 int8 values, split by index parity via
 *     load_q8_even_odd_16.
 *
 * Returns the int32 sum of the 32 products. The nibbles are used as raw
 * values 0..15: no zero-point offset or scale is applied in the visible
 * code. The result magnitude is at most 32 * 15 * 128 = 61440, so neither
 * the per-lane accumulation nor the final sum can overflow.
 *
 * NOTE(review): this pairs each byte's two nibbles with interleaved
 * (even/odd) q8 elements. If the q4 packing instead stores element j in
 * the low nibble and element j + 16 in the high nibble (as ggml's Q4_0
 * does), this pairing is wrong unless q8 is permuted to match. Confirm
 * against the quantizer.
 */
70 static inline int32_t dot_q4_q8_32_vnni(
const uint8_t *q4,
72 const __m128i q4_packed = _mm_loadu_si128((
const __m128i *)q4);
/*
 * Zero-extend each byte to a 16-bit lane, then split it into its low
 * nibble (x & 0xF) and high nibble ((x >> 4) & 0xF).
 * NOTE(review): the widen/narrow round trip could be replaced by masking
 * q4_packed and _mm_srli_epi16(q4_packed, 4) with _mm_set1_epi8(0x0F);
 * that masking also discards bits shifted in from the neighbouring byte.
 */
73 const __m256i q4_16 = _mm256_cvtepu8_epi16(q4_packed);
74 const __m256i mask4 = _mm256_set1_epi16(0x0F);
76 const __m256i q4_lo16 = _mm256_and_si256(q4_16, mask4);
77 const __m256i q4_hi16 = _mm256_and_si256(_mm256_srli_epi16(q4_16, 4), mask4);
/* Narrow back to 16 bytes each. Values are 0..15, so unsigned saturation
 * never triggers: q4_lo8[i] = q4[i] & 0xF and q4_hi8[i] = q4[i] >> 4. */
79 const __m128i q4_lo8 = _mm_packus_epi16(_mm256_castsi256_si128(q4_lo16),
80 _mm256_extracti128_si256(q4_lo16, 1));
81 const __m128i q4_hi8 = _mm_packus_epi16(_mm256_castsi256_si128(q4_hi16),
82 _mm256_extracti128_si256(q4_hi16, 1));
/* Low 128 bits: low nibbles; high 128 bits: high nibbles. */
83 const __m256i q4_bytes = _mm256_set_m128i(q4_hi8, q4_lo8);
/* q8_even8 = q8[0], q8[2], ..., q8[30]; q8_odd8 = q8[1], q8[3], ..., q8[31]. */
87 load_q8_even_odd_16(q8, &q8_even8, &q8_odd8);
/* Mirror q4_bytes: even q8 bytes under low nibbles, odd under high. */
88 const __m256i q8_bytes = _mm256_set_m128i(q8_odd8, q8_even8);
89 __m256i acc = _mm256_setzero_si256();
/* Each 32-bit lane accumulates four u8 * s8 products. q4 is the unsigned
 * operand and q8 the signed one. This is plain vpdpbusd, not the
 * saturating vpdpbusds. */
90 acc = _mm256_dpbusd_epi32(acc, q4_bytes, q8_bytes);
91 return hsum256_epi32(acc);
Quantization block structures for weight-only quantization.
void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)