General matrix multiply (GEMM) kernels with SIMD (SSE/AVX/AVX512). See the detailed description below.
Go to the source code of this file.
Functions | |
| static void | ck_gemm_add_bias (float *C, const float *bias, int M, int N) |
| static int | ck_min (int a, int b) |
| void | gemm_avx512_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_blocked_serial (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_fine_grained_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_naive_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| static void | gemm_naive_serial_double (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nn_avx512 (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nn_blocked (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_nn_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| static void | gemm_nn_serial_double (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| static void | gemm_nt_matvec_parallel (const float *A, const float *B, const float *bias, float *C, int N, int K) |
| void | gemm_tn_avx512 (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_tn_blocked (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_tn_parallel (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| static void | gemm_tn_serial_double (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
General matrix multiply (GEMM) kernels with SIMD (SSE/AVX/AVX512)
After modifying this file, run `make test && make llamacpp-parity-full`.
LEGACY EXCEPTION: This file uses OpenMP for backward compatibility. New kernels should NOT use OpenMP internally.
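GEMM: $C \leftarrow \alpha\, A B + \beta\, C$, with an optional bias added. Note: none of the functions listed above take `alpha` or `beta` arguments; their signatures accept only `A`, `B`, `bias`, `C` and the dimension arguments (`M`, `N`, `K`).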
Definition in file gemm_kernels.c.
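| static void ck_gemm_add_bias | ( | float * | C, |
| const float * | bias, | ||
| int | M, | ||
| int | N | ||
| ) |
inline static |
Referenced by gemm_blocked_serial().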
| static int ck_min | ( | int | a, |
| int | b | ||
| ) |
inline static |
Definition at line 26 of file gemm_kernels.c.
Referenced by gemm_blocked_serial(), gemm_fine_grained_parallel(), gemm_nn_blocked(), and gemm_tn_blocked().
| void gemm_avx512_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 149 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), gemm_naive_parallel(), and gemm_naive_serial_double().
| void gemm_blocked_serial | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 661 of file gemm_kernels.c.
References C, ck_gemm_add_bias(), ck_get_num_threads(), ck_min(), ck_strict_parity_enabled(), gemm_microkernel(), gemm_naive_serial_double(), and gemm_nt_matvec_parallel().
Referenced by ck_attention_project_head_major(), ck_gemm_nt_quant(), ck_mlp_swiglu_forward(), ck_mlp_swiglu_forward_fused_token(), ck_qkv_project_head_major(), ck_qkv_project_head_major_token(), mlp_token_parallel(), and mlp_token_parallel_exact().
| void gemm_fine_grained_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 205 of file gemm_kernels.c.
References C, ck_min(), ck_strict_parity_enabled(), gemm_naive_parallel(), and gemm_naive_serial_double().
| void gemm_naive_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 125 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), and gemm_naive_serial_double().
Referenced by ck_attention_project_head_major_ref(), ck_mlp_swiglu_forward_ref(), ck_qkv_project_head_major_ref(), gemm_avx512_parallel(), and gemm_fine_grained_parallel().
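| static void gemm_naive_serial_double | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
static |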
Definition at line 107 of file gemm_kernels.c.
References C.
Referenced by gemm_avx512_parallel(), gemm_blocked_serial(), gemm_fine_grained_parallel(), and gemm_naive_parallel().
| void gemm_nn_avx512 | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 339 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), gemm_nn_parallel(), and gemm_nn_serial_double().
Referenced by fc1_backward_kernel(), and fc2_backward_kernel().
| void gemm_nn_blocked | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 402 of file gemm_kernels.c.
References C, ck_min(), ck_strict_parity_enabled(), and gemm_nn_serial_double().
| void gemm_nn_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 317 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), and gemm_nn_serial_double().
Referenced by gemm_nn_avx512().
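| static void gemm_nn_serial_double | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
static |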
Definition at line 300 of file gemm_kernels.c.
References C.
Referenced by gemm_nn_avx512(), gemm_nn_blocked(), and gemm_nn_parallel().
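| static void gemm_nt_matvec_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | N, | ||
| int | K | ||
| ) |
static |
Referenced by gemm_blocked_serial().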
| void gemm_tn_avx512 | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 521 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), gemm_tn_parallel(), and gemm_tn_serial_double().
Referenced by fc1_backward_kernel(), and fc2_backward_kernel().
| void gemm_tn_blocked | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 581 of file gemm_kernels.c.
References C, ck_min(), ck_strict_parity_enabled(), and gemm_tn_serial_double().
| void gemm_tn_parallel | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
Definition at line 499 of file gemm_kernels.c.
References C, ck_strict_parity_enabled(), and gemm_tn_serial_double().
Referenced by gemm_tn_avx512().
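| static void gemm_tn_serial_double | ( | const float * | A, |
| const float * | B, | ||
| const float * | bias, | ||
| float * | C, | ||
| int | M, | ||
| int | N, | ||
| int | K | ||
| ) |
static |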
Definition at line 481 of file gemm_kernels.c.
References C.
Referenced by gemm_tn_avx512(), gemm_tn_blocked(), and gemm_tn_parallel().