24 #include <immintrin.h>
33 static inline int32_t hsum256_epi32(__m256i v) {
34 __m128i lo = _mm256_castsi256_si128(v);
35 __m128i hi = _mm256_extracti128_si256(v, 1);
36 __m128i sum = _mm_add_epi32(lo, hi);
37 sum = _mm_hadd_epi32(sum, sum);
38 sum = _mm_hadd_epi32(sum, sum);
39 return _mm_cvtsi128_si32(sum);
42 static inline void load_q8_even_odd_16(
const int8_t *q8,
45 const __m128i q8_lo = _mm_loadu_si128((
const __m128i *)q8);
46 const __m128i q8_hi = _mm_loadu_si128((
const __m128i *)(q8 + 16));
47 const __m128i even_mask = _mm_setr_epi8(
48 0, 2, 4, 6, 8, 10, 12, 14,
49 (
char)0x80, (
char)0x80, (
char)0x80, (
char)0x80,
50 (
char)0x80, (
char)0x80, (
char)0x80, (
char)0x80);
51 const __m128i odd_mask = _mm_setr_epi8(
52 1, 3, 5, 7, 9, 11, 13, 15,
53 (
char)0x80, (
char)0x80, (
char)0x80, (
char)0x80,
54 (
char)0x80, (
char)0x80, (
char)0x80, (
char)0x80);
56 const __m128i q8_lo_even = _mm_shuffle_epi8(q8_lo, even_mask);
57 const __m128i q8_hi_even = _mm_shuffle_epi8(q8_hi, even_mask);
58 const __m128i q8_even = _mm_unpacklo_epi64(q8_lo_even, q8_hi_even);
60 const __m128i q8_lo_odd = _mm_shuffle_epi8(q8_lo, odd_mask);
61 const __m128i q8_hi_odd = _mm_shuffle_epi8(q8_hi, odd_mask);
62 const __m128i q8_odd = _mm_unpacklo_epi64(q8_lo_odd, q8_hi_odd);
64 *even16 = _mm256_cvtepi8_epi16(q8_even);
65 *odd16 = _mm256_cvtepi8_epi16(q8_odd);
/*
 * 32-element dot product between 4-bit unsigned quants and 8-bit signed
 * quants, using AVX2.
 *
 * q4 : 16 bytes, two quants per byte — pairing with q8's even/odd split
 *      implies byte i holds elements 2i (low nibble) and 2i+1 (high nibble).
 * q8 : 32 int8_t values, one per element.
 *
 * Nibbles are taken as raw 0..15 values; any zero-point or scale is the
 * caller's responsibility — TODO confirm against the quantization format.
 */
static inline int32_t dot_q4_q8_32_avx2(const uint8_t *q4, const int8_t *q8) {
    const __m128i packed = _mm_loadu_si128((const __m128i *)q4);
    const __m128i nib_mask = _mm_set1_epi8(0x0F);

    /* Split every byte into its two nibbles (still 8-bit lanes). */
    const __m128i nib_even = _mm_and_si128(packed, nib_mask);
    const __m128i nib_odd = _mm_and_si128(_mm_srli_epi16(packed, 4), nib_mask);

    /* Zero-extend nibbles to 16-bit lanes for madd. */
    const __m256i w_even = _mm256_cvtepu8_epi16(nib_even);
    const __m256i w_odd = _mm256_cvtepu8_epi16(nib_odd);

    /* Even q8 elements multiply low nibbles, odd elements high nibbles. */
    __m256i q8_even16, q8_odd16;
    load_q8_even_odd_16(q8, &q8_even16, &q8_odd16);

    const __m256i acc_even = _mm256_madd_epi16(w_even, q8_even16);
    const __m256i acc_odd = _mm256_madd_epi16(w_odd, q8_odd16);
    return hsum256_epi32(acc_even) + hsum256_epi32(acc_odd);
}
/* Quantization block structures for weight-only quantization. */
void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)