Batch GEMM kernels for quantized weights with INT8 activations. More...
#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include <math.h>
#include "ckernel_quant.h"
Go to the source code of this file.
Macros | |
| #define | AMX_TILE_K 64 |
| #define | AMX_TILE_M 16 |
| #define | AMX_TILE_N 16 |
| #define | HAS_AMX 0 |
| #define | QK5_0 32 /* Q5_0: 32 weights per block */ |
| #define | QK8_0 32 /* Q8_0: 32 weights per block */ |
Functions | |
| const char * | gemm_batch_int8_impl_name (void) |
| Get the best implementation name for logging/debugging. More... | |
| void | gemm_nt_q5_0_q8_0_ref (const void *A, const void *B, float *C, int M, int N, int K) |
| Scalar reference: gemm_nt_q5_0_q8_0. More... | |
| void | gemm_nt_q8_0_q8_0 (const void *A, const void *B, const float *bias, float *C, int M, int N, int K) |
| gemm_nt_q8_0_q8_0 with optional bias (matches header signature) More... | |
| void | gemm_nt_q8_0_q8_0_ref (const void *A, const void *B, float *C, int M, int N, int K) |
| Scalar reference: gemm_nt_q8_0_q8_0. More... | |
Batch GEMM kernels for quantized weights with INT8 activations.
After changes: make test && make llamacpp-parity-full
Implements batch matrix multiplication where C[M,N] = A[M,K] @ B[N,K]^T (B is transposed, i.e. row-major weights).
Instruction Set Implementations:
Design Philosophy:
Definition in file gemm_batch_int8.c.
| #define AMX_TILE_K 64 |
Definition at line 65 of file gemm_batch_int8.c.
| #define AMX_TILE_M 16 |
Definition at line 63 of file gemm_batch_int8.c.
| #define AMX_TILE_N 16 |
Definition at line 64 of file gemm_batch_int8.c.
| #define HAS_AMX 0 |
Definition at line 52 of file gemm_batch_int8.c.
| #define QK5_0 32 /* Q5_0: 32 weights per block */ |
Definition at line 60 of file gemm_batch_int8.c.
| #define QK8_0 32 /* Q8_0: 32 weights per block */ |
Definition at line 59 of file gemm_batch_int8.c.
const char * gemm_batch_int8_impl_name ( void )
Get the best implementation name for logging/debugging.
Definition at line 553 of file gemm_batch_int8.c.
void gemm_nt_q5_0_q8_0_ref ( const void * A, const void * B, float * C, int M, int N, int K )
Scalar reference: gemm_nt_q5_0_q8_0.
Q5_0 weight reconstruction: weight[j] = d * ((qs_nibble | (qh_bit << 4)) - 16)
For j in 0..15, combine the low nibble of qs[j] with qh bit j; for j in 16..31, combine the high nibble of qs[j-16] with qh bit j.
| A | Input activations [M, K] in Q8_0 format |
| B | Weight matrix [N, K] in Q5_0 format |
| C | Output matrix [M, N] in FP32 |
| M | Number of tokens (batch size) |
| N | Number of output features |
| K | Number of input features (must be multiple of 32) |
Definition at line 391 of file gemm_batch_int8.c.
References C, CK_FP16_TO_FP32, QK5_0, and block_q8_0::qs.
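The reconstruction rule above can be sketched as a standalone block dequantizer. This is a minimal illustration, not the file's actual code: the simplified struct below stores the scale as a plain float, whereas the real block_q5_0 stores it as fp16 decoded via CK_FP16_TO_FP32, and the field names are assumptions.

```c
#include <stdint.h>
#include <string.h>

#define QK5_0 32

/* Simplified Q5_0 block: real layout stores d as fp16. */
typedef struct {
    float   d;              /* per-block scale */
    uint8_t qh[4];          /* 32 high (5th) bits, one per weight */
    uint8_t qs[QK5_0 / 2];  /* low 4 bits, two weights per byte */
} blk_q5_0;

/* weight[j] = d * ((nibble | (qh_bit << 4)) - 16)
 * j in 0..15  -> low nibble of qs[j]
 * j in 16..31 -> high nibble of qs[j-16]
 * The 5th bit of weight j is bit j of the 32-bit qh field. */
static void dequant_q5_0_block(const blk_q5_0 *blk, float out[QK5_0])
{
    uint32_t qh;
    memcpy(&qh, blk->qh, sizeof qh);  /* assemble the 32 high bits */
    for (int j = 0; j < QK5_0; j++) {
        uint8_t nib = (j < 16) ? (uint8_t)(blk->qs[j] & 0x0F)
                               : (uint8_t)(blk->qs[j - 16] >> 4);
        uint8_t hi  = (uint8_t)((qh >> j) & 1);
        out[j] = blk->d * (float)((int)(nib | (hi << 4)) - 16);
    }
}
```

With all nibbles and qh bits zero, every weight decodes to d * (0 - 16); setting qh bit j alone moves weight j up by 16 quantization steps.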
void gemm_nt_q8_0_q8_0 ( const void * A, const void * B, const float * bias, float * C, int M, int N, int K )
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
C[m,n] = A[m,K] @ B[n,K]^T + bias[n]
Definition at line 582 of file gemm_batch_int8.c.
References C, and gemm_nt_q8_0_q8_0_ref().
Referenced by ck_test_gemm_q8_0(), gemm_nt_q8_0_dispatch(), and gemm_nt_q8_0_mlp_dispatch().
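The optional-bias behavior can be sketched as a post-pass over C: compute the GEMM as usual, then add bias[n] to every element of column n when bias is non-NULL. The helper name below is illustrative, not the file's actual internals.

```c
#include <stddef.h>

/* Add bias[n] to each C[m,n]; a NULL bias means "no bias" and is a no-op. */
static void add_bias(float *C, const float *bias, int M, int N)
{
    if (bias == NULL)
        return;  /* bias is optional */
    for (int m = 0; m < M; m++)
        for (int n = 0; n < N; n++)
            C[m * N + n] += bias[n];
}
```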
void gemm_nt_q8_0_q8_0_ref ( const void * A, const void * B, float * C, int M, int N, int K )
Scalar reference: gemm_nt_q8_0_q8_0.
C[m,n] = sum_k( dequant(A[m,k]) * dequant(B[n,k]) ) = sum_blocks( d_a * d_b * sum_j(a_qs[j] * b_qs[j]) )
| A | Input activations [M, K] in Q8_0 format |
| B | Weight matrix [N, K] in Q8_0 format (row-major, each row is one output) |
| C | Output matrix [M, N] in FP32 |
| M | Number of tokens (batch size) |
| N | Number of output features (rows in B) |
| K | Number of input features (must be multiple of 32) |
Definition at line 87 of file gemm_batch_int8.c.
References C, CK_FP16_TO_FP32, QK8_0, and block_q8_0::qs.
Referenced by gemm_nt_q8_0_q8_0().
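The per-block accumulation formula above (d_a * d_b times an integer dot product of 32 int8 pairs) can be sketched as a scalar loop. This is a simplified illustration, not the file's code: the struct stores the scale as a plain float, while the real block_q8_0 uses fp16 decoded via CK_FP16_TO_FP32.

```c
#include <stdint.h>

#define QK8_0 32

/* Simplified Q8_0 block: real layout stores d as fp16. */
typedef struct {
    float  d;          /* per-block scale */
    int8_t qs[QK8_0];  /* 32 quantized values */
} blk_q8_0;

/* C[m,n] = sum over blocks of d_a * d_b * (int8 dot product of the block) */
static void gemm_nt_q8_0_sketch(const blk_q8_0 *A, const blk_q8_0 *B,
                                float *C, int M, int N, int K)
{
    int kb = K / QK8_0;  /* blocks per row; K must be a multiple of 32 */
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float acc = 0.0f;
            for (int b = 0; b < kb; b++) {
                const blk_q8_0 *pa = &A[m * kb + b];
                const blk_q8_0 *pb = &B[n * kb + b];
                int32_t dot = 0;  /* exact integer dot, scaled once per block */
                for (int j = 0; j < QK8_0; j++)
                    dot += (int32_t)pa->qs[j] * (int32_t)pb->qs[j];
                acc += pa->d * pb->d * (float)dot;
            }
            C[m * N + n] = acc;
        }
    }
}
```

Keeping the inner dot product in int32 and applying the two scales once per block is what makes the blocked formulation cheaper than dequantizing every element to float first.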