← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q4k_q8k_vnni.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q4k_q8k_vnni.c
3  * @brief VNNI Q4_K x Q8_K matvec kernel (inference only)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
 * The SIMD helpers require AVX512-VNNI and AVX512-VL (vpdpbusd on 256-bit
 * vectors). gemv_q4_k_q8_k_vnni currently forwards to the reference kernel.
15  */
16 
17 #include <stddef.h>
18 #include <stdint.h>
19 #include <string.h>
20 
21 #include "ckernel_quant.h"
22 
23 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
24 #include <immintrin.h>
25 #endif
26 
/*
 * Reference Q4_K x Q8_K matvec, defined in another translation unit.
 * gemv_q4_k_q8_k_vnni() below forwards every call to it.
 * NOTE(review): presumably y receives M floats and W holds M rows of K
 * Q4_K-quantized values (block layout from ckernel_quant.h) — confirm.
 */
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K);

/*
 * AVX2 variant with the same signature, defined elsewhere.
 * Not referenced anywhere in this translation unit.
 */
void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K);
36 
37 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
/*
 * Horizontal sum of the eight signed 32-bit lanes of v.
 * Uses wrap-around (modulo 2^32) integer addition, like _mm_add_epi32.
 */
static inline int32_t hsum256_epi32(__m256i v) {
    /* 256 -> 128 bits: add lanes 4..7 onto lanes 0..3. */
    __m128i acc = _mm_add_epi32(_mm256_castsi256_si128(v),
                                _mm256_extracti128_si256(v, 1));
    /* 128 -> 64 bits: add the swapped 64-bit halves. */
    acc = _mm_add_epi32(acc, _mm_shuffle_epi32(acc, _MM_SHUFFLE(1, 0, 3, 2)));
    /* 64 -> 32 bits: add the neighbouring 32-bit lane. */
    acc = _mm_add_epi32(acc, _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1)));
    return _mm_cvtsi128_si32(acc);
}
46 
/*
 * De-interleave 32 signed bytes starting at q8 into two 16-byte vectors:
 *   *even8 = q8[0], q8[2], ..., q8[30]
 *   *odd8  = q8[1], q8[3], ..., q8[31]
 * Reads exactly 32 bytes with unaligned loads.
 */
static inline void load_q8_even_odd_16(const int8_t *q8,
                                       __m128i *even8,
                                       __m128i *odd8) {
    /* Within each 16-byte half: even-indexed bytes go to the low 8 bytes,
     * odd-indexed bytes go to the high 8 bytes. */
    const __m128i split = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14,
                                        1, 3, 5, 7, 9, 11, 13, 15);
    const __m128i first = _mm_shuffle_epi8(
        _mm_loadu_si128((const __m128i *)q8), split);
    const __m128i second = _mm_shuffle_epi8(
        _mm_loadu_si128((const __m128i *)(q8 + 16)), split);

    /* Join the matching 64-bit halves of the two shuffled vectors. */
    *even8 = _mm_unpacklo_epi64(first, second);
    *odd8 = _mm_unpackhi_epi64(first, second);
}
69 
/*
 * Dot product of 32 unsigned 4-bit values with 32 signed 8-bit values.
 *
 * q4 points at 16 packed bytes and q8 points at 32 bytes. The result is
 *   sum_{i=0..15} (q4[i] & 0x0F) * q8[2*i] + (q4[i] >> 4) * q8[2*i + 1]
 * so byte i of q4 carries element 2i in its low nibble and element 2i+1 in
 * its high nibble.
 *
 * NOTE(review): this interleaved nibble order is presumably not the Q4_K
 * layout the TODO in gemv_q4_k_q8_k_vnni() refers to (that function does not
 * call this helper). Confirm the layout before using this in a Q4_K kernel.
 */
static inline int32_t dot_q4_q8_32_vnni(const uint8_t *q4,
                                        const int8_t *q8) {
    const __m128i nibble = _mm_set1_epi8(0x0F);
    const __m128i packed = _mm_loadu_si128((const __m128i *)q4);

    /* Byte-wise nibble split. The 16-bit shift moves each byte's high nibble
     * into its low nibble. The mask discards bits shifted in from the
     * neighbouring byte. */
    const __m128i lo_nib = _mm_and_si128(packed, nibble);
    const __m128i hi_nib = _mm_and_si128(_mm_srli_epi16(packed, 4), nibble);

    __m128i q8_even;
    __m128i q8_odd;
    load_q8_even_odd_16(q8, &q8_even, &q8_odd);

    /* vpdpbusd multiplies unsigned bytes (first operand) by signed bytes
     * (second operand) position by position, accumulating into int32 lanes.
     * Low nibbles therefore meet even q8 bytes and high nibbles meet odd ones. */
    const __m256i prod = _mm256_dpbusd_epi32(_mm256_setzero_si256(),
                                             _mm256_set_m128i(hi_nib, lo_nib),
                                             _mm256_set_m128i(q8_odd, q8_even));
    return hsum256_epi32(prod);
}
93 #endif
94 
/**
 * @brief Q4_K x Q8_K matrix-vector product, VNNI entry point.
 *
 * Currently this is a pure forwarder. Every argument is passed unchanged to
 * gemv_q4_k_q8_k_ref(), so the results are exactly those of the reference
 * kernel on every build, with or without AVX512-VNNI. The reference kernel
 * is the implementation that has been fixed to use the correct Q4_K layout.
 *
 * @param y     Output vector, written by the reference kernel
 *              (presumably M floats — confirm against caller).
 * @param W     Q4_K-quantized weights (block layout presumably from
 *              ckernel_quant.h).
 * @param x_q8  Q8_K-quantized input vector.
 * @param M     Forwarded unchanged (presumably the number of output rows).
 * @param K     Forwarded unchanged (presumably the input length).
 *
 * TODO: implement a vpdpbusd path that follows the correct Q4_K memory
 * layout. Until then, the reference implementation is used.
 */
void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K) {
    gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
}
Quantization block structures for weight-only quantization.
void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)