Fused GEMM Kernels with activations.
Functions
| Return type | Function |
|---|---|
| static float | fast_gelu_scalar (float x) |
| void | gemm_bias_gelu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_bias_relu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_bias_silu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K) |
| void | gemm_swiglu_fused (const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K) |
After changing this file, run: make test && make llamacpp-parity-full
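This file fuses GEMM with a bias add and an activation (ReLU, GELU, or SiLU), and also provides a fused SwiGLU projection. The key benefit is that intermediate results stay in registers: the activation is applied to each accumulated value before it is stored, so the pre-activation GEMM output is never written to DRAM and read back by a separate pass.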
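As a hedged illustration of the fusion idea, the sketch below is a plain scalar reference loop for GEMM + bias + ReLU, not the file's vectorized implementation. The function name gemm_bias_relu_reference and the data layouts (A as M×K row-major, B as N×K row-major with one contiguous weight row per output column, C as M×N row-major) are assumptions for illustration only; this page does not document the real layouts. Each output value is accumulated in a local variable, the bias and activation are applied there, and the result is stored exactly once.

```c
/* Hypothetical scalar reference for a fused GEMM + bias + ReLU.
 * Assumed layouts (not confirmed by the documentation):
 *   A    : M x K, row-major
 *   B    : N x K, row-major (one contiguous weight row per output column)
 *   bias : N
 *   C    : M x N, row-major
 */
static void gemm_bias_relu_reference(const float *A, const float *B,
                                     const float *bias, float *C,
                                     int M, int N, int K)
{
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            /* The accumulator is a local value the compiler can keep in a
             * register; it starts from the bias instead of adding it later. */
            float acc = bias[j];
            for (int k = 0; k < K; k++) {
                acc += A[i * K + k] * B[j * K + k];
            }
            /* ReLU is applied before the single store, so no
             * pre-activation matrix is ever written to memory. */
            C[i * N + j] = acc > 0.0f ? acc : 0.0f;
        }
    }
}
```

For example, with A = {1, 2, 3, -1, -2, -3}, B = {1, 0, 0, 0, 1, 1}, bias = {0.5, -1} and M = N = 2, K = 3, the call fills C with {1.5, 4, 0, 0}; the second row's negative sums are clamped to zero by the fused ReLU.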
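Supported operations:
- gemm_bias_relu_fused: GEMM + bias + ReLU
- gemm_bias_gelu_fused: GEMM + bias + GELU
- gemm_bias_silu_fused: GEMM + bias + SiLU
- gemm_swiglu_fused: SwiGLU, i.e. SiLU of the gate projection (x with W_gate, plus b_gate) multiplied elementwise by the up projection (x with W_up, plus b_up)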
All kernels support:
Definition in file gemm_fused_kernels.c.
static inline float fast_gelu_scalar(float x)
Definition at line 74 of file gemm_fused_kernels.c.
Referenced by gemm_bias_gelu_fused().
void gemm_bias_gelu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition at line 131 of file gemm_fused_kernels.c.
References C, fast_gelu_scalar(), and hsum256_ps_fused().
void gemm_bias_relu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void gemm_bias_silu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void gemm_swiglu_fused(const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)
Definition at line 241 of file gemm_fused_kernels.c.
References hsum256_ps_fused().
Referenced by ck_mlp_swiglu_forward_fused_token().