← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_amx.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_amx.c
3  * @brief AMX (Advanced Matrix Extensions) GEMM kernels
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Intel AMX provides dedicated matrix multiply hardware:
15  * - 8 tile registers (TMM0-TMM7), each up to 1KB
16  * - TDPBSSD: INT8 signed dot product (A signed, B signed)
17  * - TDPBSUD: INT8 mixed sign (A signed, B unsigned)
18  * - TDPBUSD: INT8 mixed sign (A unsigned, B signed)
19  * - TDPBUUD: INT8 unsigned dot product
20  * - TDPBF16PS: BF16 dot product to FP32
21  *
22  * Tile dimensions:
23  * - Max: 16 rows x 64 bytes (1024 bytes per tile)
24  * - For INT8: 16x64 elements
25  * - For BF16: 16x32 elements
26  *
27  * Performance:
28  * - AMX INT8: ~2000 INT8 ops/cycle (vs ~256 for AVX-512 VNNI)
29  * - AMX BF16: ~1000 BF16 ops/cycle
30  * - Expected 8-16x speedup over AVX-512 for large GEMM
31  *
32  * Requirements:
33  * - Sapphire Rapids or newer (4th Gen Xeon)
34  * - Linux kernel 5.16+ with AMX support
35  * - Compiler: GCC 11+, Clang 12+, ICX 2022+
36  */
37 
38 #include <stdint.h>
39 #include <stddef.h>
40 #include <string.h>
41 #include <stdbool.h>
42 
43 #include "ckernel_quant.h"
44 
45 /* AMX requires specific compiler support */
46 #if defined(__AMX_INT8__) || defined(__AMX_TILE__)
47 
48 #include <immintrin.h>
49 
/* Tile configuration structure consumed by LDTILECFG / _tile_loadconfig.
 * The layout is fixed by the hardware and must be exactly 64 bytes:
 *   byte  0     : palette_id (1 = standard 8-tile palette)
 *   byte  1     : start_row  (resume point for interrupted tile loads)
 *   bytes 2-15  : reserved, must be zero
 *   bytes 16-47 : colsb[16]  - per-tile row width in BYTES (max 64)
 *   bytes 48-63 : rows[16]   - per-tile row count (max 16)
 */
typedef struct __tile_config {
 uint8_t palette_id;
 uint8_t start_row;
 uint8_t reserved_0[14];
 uint16_t colsb[16]; /* Columns in bytes for each tile */
 uint8_t rows[16]; /* Rows for each tile */
} __tile_config;
58 
59 /* AMX tile dimensions for our use case */
60 #define AMX_TILE_M 16 /* Rows per tile (matches hardware max) */
61 #define AMX_TILE_N 16 /* Output columns (16 int32 = 64 bytes) */
62 #define AMX_TILE_K 64 /* K dimension (64 int8 = 64 bytes) */
63 
64 /* Tile register assignments for GEMM:
65  * TMM0: A tile (activations)
66  * TMM1: B tile (weights)
67  * TMM2: C tile (accumulator)
68  * TMM3-7: Reserved for larger blocking
69  */
70 #define TILE_A 0
71 #define TILE_B 1
72 #define TILE_C 2
73 
74 /**
75  * @brief Configure AMX tiles for GEMM operation
76  *
77  * Must be called before using tile instructions.
78  * Tiles:
79  * - TILE_A: M x K bytes (activations, INT8)
80  * - TILE_B: K x N bytes (weights, INT8, must be K rows for TDPB*)
81  * - TILE_C: M x (N*4) bytes (accumulator, INT32)
82  */
83 static void configure_tiles_gemm(int M, int N, int K) {
84  __tile_config config = {0};
85 
86  config.palette_id = 1; /* Use palette 1 (standard) */
87 
88  /* Tile A: M rows, K columns (K bytes per row) */
89  int tile_m = (M > AMX_TILE_M) ? AMX_TILE_M : M;
90  int tile_k = (K > AMX_TILE_K) ? AMX_TILE_K : K;
91  int tile_n = (N > AMX_TILE_N) ? AMX_TILE_N : N;
92 
93  /* TILE_A: Input activations (INT8) */
94  config.rows[TILE_A] = tile_m;
95  config.colsb[TILE_A] = tile_k;
96 
97  /* TILE_B: Weights (INT8) - note: for TDPBSSD, B has K rows, N cols */
98  config.rows[TILE_B] = tile_k;
99  config.colsb[TILE_B] = tile_n * 4; /* N int32 outputs = N*4 bytes */
100 
101  /* TILE_C: Accumulator (INT32) */
102  config.rows[TILE_C] = tile_m;
103  config.colsb[TILE_C] = tile_n * 4; /* N int32 outputs = N*4 bytes */
104 
105  _tile_loadconfig(&config);
106 }
107 
108 /**
109  * @brief Release AMX tile configuration
110  */
111 static void release_tiles(void) {
112  _tile_release();
113 }
114 
115 /**
116  * @brief AMX INT8 GEMM: C[M,N] += A[M,K] @ B[K,N]
117  *
118  * Uses TDPBSSD (signed int8 x signed int8 -> int32 accumulate)
119  *
120  * @param A INT8 activations [M, K], row-major
121  * @param B INT8 weights [K, N], column-major (transposed for efficiency)
122  * @param C INT32 accumulator [M, N], row-major
123  * @param M Output rows
124  * @param N Output columns
125  * @param K Inner dimension
126  */
127 void gemm_amx_int8_core(
128  const int8_t *A,
129  const int8_t *B,
130  int32_t *C,
131  int M, int N, int K)
132 {
133  /* Configure tiles for this GEMM size */
134  configure_tiles_gemm(M, N, K);
135 
136  /* Process in tiles */
137  for (int m = 0; m < M; m += AMX_TILE_M) {
138  int tile_m = (m + AMX_TILE_M <= M) ? AMX_TILE_M : (M - m);
139 
140  for (int n = 0; n < N; n += AMX_TILE_N) {
141  int tile_n = (n + AMX_TILE_N <= N) ? AMX_TILE_N : (N - n);
142 
143  /* Zero accumulator tile */
144  _tile_zero(TILE_C);
145 
146  /* Accumulate over K dimension */
147  for (int k = 0; k < K; k += AMX_TILE_K) {
148  int tile_k = (k + AMX_TILE_K <= K) ? AMX_TILE_K : (K - k);
149 
150  /* Load A tile: A[m:m+tile_m, k:k+tile_k] */
151  _tile_loadd(TILE_A, A + m * K + k, K);
152 
153  /* Load B tile: B[k:k+tile_k, n:n+tile_n]
154  * Note: B is stored column-major for efficient AMX access */
155  _tile_loadd(TILE_B, B + k * N + n, N * 4);
156 
157  /* TDPBSSD: C += A @ B (signed int8 dot product) */
158  _tile_dpbssd(TILE_C, TILE_A, TILE_B);
159  }
160 
161  /* Store C tile: C[m:m+tile_m, n:n+tile_n] */
162  _tile_stored(TILE_C, C + m * N + n, N * 4);
163  }
164  }
165 
166  release_tiles();
167 }
168 
169 /**
170  * @brief AMX GEMV for Q4_K x Q8_K: y[M] = W[M,K] @ x[K]
171  *
172  * Adapts the tile-based AMX for vector operations by treating the vector
173  * as a 1-row matrix.
174  *
175  * For decode (single token), M=1, so we batch multiple output rows together
176  * to better utilize AMX tiles.
177  *
178  * @param y Output (FP32), shape [M]
179  * @param W Weights in Q4_K format, shape [M, K]
180  * @param x_q8 Input in Q8_K format, shape [K]
181  * @param M Output dimension
182  * @param K Input dimension (must be multiple of 256)
183  */
184 /* Forward declarations for fallback chain: VNNI → AVX2 → AVX → ref */
185 void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K);
186 void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K);
187 void gemv_q4_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K);
188 void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K);
189 
190 void gemv_q4_k_q8_k_amx(float *y,
191  const void *W,
192  const void *x_q8,
193  int M, int K)
194 {
195  /* AMX is best for Q8_0 x Q8_0 (uniform INT8).
196  * For Q4_K x Q8_K, the per-block scales make AMX less efficient.
197  * Fall back through: VNNI → AVX2 → AVX → ref
198  *
199  * TODO: Implement true AMX path by dequantizing Q4_K to INT8 first
200  */
201 #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
202  gemv_q4_k_q8_k_vnni(y, W, x_q8, M, K);
203 #elif defined(__AVX2__)
204  gemv_q4_k_q8_k_avx2(y, W, x_q8, M, K);
205 #elif defined(__AVX__)
206  gemv_q4_k_q8_k_avx(y, W, x_q8, M, K);
207 #else
208  gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
209 #endif
210 }
211 
212 /* NOTE: gemm_nt_q8_0_q8_0_amx is defined in gemm_batch_int8.c */
213 
214 /**
215  * @brief Check if AMX is available at runtime
216  */
217 bool amx_available(void) {
218  /* Check CPUID for AMX support */
219  unsigned int eax, ebx, ecx, edx;
220 
221  /* CPUID leaf 7, subleaf 0 */
222  __asm__ __volatile__(
223  "cpuid"
224  : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
225  : "a"(7), "c"(0)
226  );
227 
228  /* AMX-TILE: EDX bit 24 */
229  /* AMX-INT8: EDX bit 25 */
230  /* AMX-BF16: EDX bit 22 */
231  bool has_amx_tile = (edx >> 24) & 1;
232  bool has_amx_int8 = (edx >> 25) & 1;
233 
234  return has_amx_tile && has_amx_int8;
235 }
236 
237 #else /* No AMX support */
238 
239 #include <stdbool.h>
240 
241 /* Fallback declarations - use weak symbols to avoid link errors */
242 void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K);
243 void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K);
244 void gemv_q4_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K);
245 void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K);
246 
/* Non-AMX build of the _amx entry point: keeps the symbol available so
 * callers link unchanged, and dispatches (at compile time) to the best
 * SIMD kernel the build supports.
 *
 * @param y Output (FP32), shape [M]
 * @param W Weights in Q4_K format, shape [M, K]
 * @param x_q8 Input in Q8_K format, shape [K]
 * @param M Output dimension
 * @param K Input dimension
 */
void gemv_q4_k_q8_k_amx(float *y, const void *W, const void *x_q8, int M, int K) {
 /* No AMX support - cascade through fallbacks: AVX-512 VNNI → AVX2 → AVX → ref */
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
 gemv_q4_k_q8_k_vnni(y, W, x_q8, M, K);
#elif defined(__AVX2__)
 gemv_q4_k_q8_k_avx2(y, W, x_q8, M, K);
#elif defined(__AVX__)
 gemv_q4_k_q8_k_avx(y, W, x_q8, M, K);
#else
 gemv_q4_k_q8_k_ref(y, W, x_q8, M, K);
#endif
}
259 
260 /* NOTE: gemm_nt_q8_0_q8_0_amx is defined in gemm_batch_int8.c */
261 
/* Non-AMX build: the tile intrinsics were compiled out, so AMX can never
 * be used by this binary regardless of what the CPU supports. */
bool amx_available(void) {
 return false;
}
265 
266 #endif /* __AMX_INT8__ */
Quantization block structures for weight-only quantization.
#define AMX_TILE_K
#define AMX_TILE_N
#define AMX_TILE_M
void gemv_q4_k_q8_k_avx2(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_vnni(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_amx(float *y, const void *W, const void *x_q8, int M, int K)
bool amx_available(void)
void gemv_q4_k_q8_k_ref(float *y, const void *W, const void *x_q8, int M, int K)
void gemv_q4_k_q8_k_avx(float *y, const void *W, const void *x_q8, int M, int K)
#define C(color)
Definition: show_config.c:39
const CKBPEConfig * config
Definition: true_bpe.h:171