Optimized BF16 GEMM Kernels for AVX-512. More...
Go to the source code of this file.
Macros | |
| #define | BLK_K 256 |
| #define | BLK_M 64 |
| #define | BLK_N 64 |
Functions | |
| __attribute__ ((unused)) | |
| static int | ck_min_i (int a, int b) |
| void | gemm_bf16_fp32out (const uint16_t *A, const uint16_t *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_blocked_serial_bf16 (const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K) |
| void | gemm_nn_bf16 (const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K) |
| void | gemm_tn_bf16 (const uint16_t *A, const uint16_t *B, const uint16_t *bias, uint16_t *C, int M, int N, int K) |
Optimized BF16 GEMM Kernels for AVX-512.
After making changes, run: make test && make llamacpp-parity-full
Layout: A: [M x K] row-major (BF16); B: [N x K] row-major, stored as [out x in] (BF16); C: [M x N] row-major (BF16 or FP32).
Key optimizations:
Definition in file gemm_kernels_bf16.c.
| #define BLK_K 256 |
Definition at line 43 of file gemm_kernels_bf16.c.
| #define BLK_M 64 |
Definition at line 41 of file gemm_kernels_bf16.c.
| #define BLK_N 64 |
Definition at line 42 of file gemm_kernels_bf16.c.
| __attribute__ | ( | (unused) | ) |
Definition at line 51 of file gemm_kernels_bf16.c.
References bf16_to_float(), C, and float_to_bf16().
|
inline static |
Definition at line 45 of file gemm_kernels_bf16.c.
| void gemm_bf16_fp32out | ( | const uint16_t * | A, |
| const uint16_t * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 301 of file gemm_kernels_bf16.c.
References bf16_to_float(), and C.
Referenced by mlp_token_parallel_bf16(), and mlp_token_parallel_bf16_fp32act().
| void gemm_blocked_serial_bf16 | ( | const uint16_t * | A, |
| const uint16_t * | B, | ||
| const uint16_t * | bias, | ||
| uint16_t * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
| void gemm_nn_bf16 | ( | const uint16_t * | A, |
| const uint16_t * | B, | ||
| const uint16_t * | bias, | ||
| uint16_t * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 360 of file gemm_kernels_bf16.c.
References bf16_to_float(), C, and float_to_bf16().
| void gemm_tn_bf16 | ( | const uint16_t * | A, |
| const uint16_t * | B, | ||
| const uint16_t * | bias, | ||
| uint16_t * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 427 of file gemm_kernels_bf16.c.
References bf16_to_float(), C, and float_to_bf16().