← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_fused_kernels.c File Reference

Fused GEMM Kernels with activations. More...

#include "ckernel_engine.h"
#include <math.h>

Go to the source code of this file.

Functions

static float fast_gelu_scalar (float x)
 
void gemm_bias_gelu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_bias_relu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_bias_silu_fused (const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
 
void gemm_swiglu_fused (const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)
 

Detailed Description

Fused GEMM Kernels with activations.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

GEMM operations fused with activations (ReLU, GELU, SiLU) and SwiGLU. The key benefit: intermediate results stay in registers, avoiding DRAM round-trips between operations.

Supported operations:

  • gemm_bias_relu_fused: C = ReLU(A @ B^T + bias)
  • gemm_bias_gelu_fused: C = GELU(A @ B^T + bias)
  • gemm_bias_silu_fused: C = SiLU(A @ B^T + bias)
  • gemm_swiglu_fused: output = SiLU(x @ W_gate^T + b_gate) * (x @ W_up^T + b_up); b_gate/b_up may be NULL, in which case no bias is added

All kernels support:

  • AVX1 SIMD (256-bit vectors, no FMA)
  • OpenMP parallelization (NOTE: this contradicts kernel rule 2 above, "NO OpenMP" — the kernel bodies shown below do contain #pragma omp parallel for; confirm which statement is current)
  • Scalar fallback

Definition in file gemm_fused_kernels.c.

Function Documentation

◆ fast_gelu_scalar()

static float fast_gelu_scalar ( float  x)
inlinestatic

Definition at line 74 of file gemm_fused_kernels.c.

74  {
75  float sx = 1.702f * x;
76  float sig = 1.0f / (1.0f + expf(-sx));
77  return x * sig;
78 }

Referenced by gemm_bias_gelu_fused().

◆ gemm_bias_gelu_fused()

/**
 * Fused GEMM + bias + GELU: C = GELU(A @ B^T + bias).
 *
 * A:    [M x K] row-major activations
 * B:    [N x K] row-major weights (row j of B is dotted with row i of A,
 *       i.e. the product is against B transposed)
 * bias: [N] per-output-column bias
 * C:    [M x N] row-major output
 *
 * The accumulator stays in registers; bias add and activation are applied
 * before the single store, so no intermediate result touches DRAM.
 *
 * NOTE(review): kernel rule 2 in the file header says "NO OpenMP", yet the
 * pragma below is present — confirm which is current before changing it.
 */
void gemm_bias_gelu_fused(const float *A, const float *B, const float *bias,
                          float *C, int M, int N, int K)
{
#if defined(__AVX__)
#pragma omp parallel for
    for (int row = 0; row < M; row++) {
        for (int col = 0; col < N; col++) {
            __m256 acc = _mm256_setzero_ps();
            int k = 0;
            /* 8-wide multiply-accumulate (AVX1: separate mul + add, no FMA). */
            while (k <= K - 8) {
                __m256 lhs = _mm256_loadu_ps(&A[row * K + k]);
                __m256 rhs = _mm256_loadu_ps(&B[col * K + k]);
                acc = _mm256_add_ps(acc, _mm256_mul_ps(lhs, rhs));
                k += 8;
            }
            float dot = hsum256_ps_fused(acc);
            /* Scalar tail for K not divisible by 8. */
            for (; k < K; k++) {
                dot += A[row * K + k] * B[col * K + k];
            }
            dot += bias[col];
            C[row * N + col] = fast_gelu_scalar(dot);
        }
    }
#else
    /* Scalar fallback. */
#pragma omp parallel for
    for (int row = 0; row < M; row++) {
        for (int col = 0; col < N; col++) {
            float dot = 0.0f;
            for (int k = 0; k < K; k++) {
                dot += A[row * K + k] * B[col * K + k];
            }
            dot += bias[col];
            C[row * N + col] = fast_gelu_scalar(dot);
        }
    }
#endif
}
static float fast_gelu_scalar(float x)
static float hsum256_ps_fused(__m256 v)
#define C(color)
Definition: show_config.c:39

References C, fast_gelu_scalar(), and hsum256_ps_fused().
(Note: the cross-reference to the C(color) macro from show_config.c is a Doxygen false positive — the kernel's output parameter C merely shares that macro's name.)

◆ gemm_bias_relu_fused()

/**
 * Fused GEMM + bias + ReLU: C = ReLU(A @ B^T + bias).
 *
 * A:    [M x K] row-major activations
 * B:    [N x K] row-major weights (dotted row-against-row, i.e. B^T)
 * bias: [N] per-output-column bias
 * C:    [M x N] row-major output
 *
 * Bias add and clamp happen while the sum is still in a register —
 * one store per output element, no intermediate buffer.
 */
void gemm_bias_relu_fused(const float *A, const float *B, const float *bias,
                          float *C, int M, int N, int K)
{
#if defined(__AVX__)
#pragma omp parallel for
    for (int row = 0; row < M; row++) {
        for (int col = 0; col < N; col++) {
            __m256 acc = _mm256_setzero_ps();
            int k = 0;
            /* 8-wide multiply-accumulate (AVX1: no FMA available). */
            while (k <= K - 8) {
                __m256 lhs = _mm256_loadu_ps(&A[row * K + k]);
                __m256 rhs = _mm256_loadu_ps(&B[col * K + k]);
                acc = _mm256_add_ps(acc, _mm256_mul_ps(lhs, rhs));
                k += 8;
            }
            float dot = hsum256_ps_fused(acc);
            /* Scalar tail for K not divisible by 8. */
            for (; k < K; k++) {
                dot += A[row * K + k] * B[col * K + k];
            }
            /* Fused epilogue: bias then ReLU, still in register. */
            dot += bias[col];
            C[row * N + col] = (dot > 0.0f) ? dot : 0.0f;
        }
    }
#else
    /* Scalar fallback. */
#pragma omp parallel for
    for (int row = 0; row < M; row++) {
        for (int col = 0; col < N; col++) {
            float dot = 0.0f;
            for (int k = 0; k < K; k++) {
                dot += A[row * K + k] * B[col * K + k];
            }
            dot += bias[col];
            C[row * N + col] = (dot > 0.0f) ? dot : 0.0f;
        }
    }
#endif
}

References C, and hsum256_ps_fused().

◆ gemm_bias_silu_fused()

void gemm_bias_silu_fused ( const float *  A,
const float *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K 
)

Definition at line 177 of file gemm_fused_kernels.c.

182 {
183 #if defined(__AVX__)
184 #pragma omp parallel for
185  for (int i = 0; i < M; i++) {
186  for (int j = 0; j < N; j++) {
187  __m256 sum_vec = _mm256_setzero_ps();
188  int k;
189  for (k = 0; k <= K - 8; k += 8) {
190  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
191  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
192  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
193  sum_vec = _mm256_add_ps(sum_vec, prod);
194  }
195  float sum = hsum256_ps_fused(sum_vec);
196  for (; k < K; k++) {
197  sum += A[i * K + k] * B[j * K + k];
198  }
199  sum += bias[j];
200  // SiLU: x * sigmoid(x)
201  float sig = 1.0f / (1.0f + expf(-sum));
202  C[i * N + j] = sum * sig;
203  }
204  }
205 #else
206 #pragma omp parallel for
207  for (int i = 0; i < M; i++) {
208  for (int j = 0; j < N; j++) {
209  float sum = 0.0f;
210  for (int k = 0; k < K; k++) {
211  sum += A[i * K + k] * B[j * K + k];
212  }
213  sum += bias[j];
214  float sig = 1.0f / (1.0f + expf(-sum));
215  C[i * N + j] = sum * sig;
216  }
217  }
218 #endif
219 }

References C, and hsum256_ps_fused().

◆ gemm_swiglu_fused()

void gemm_swiglu_fused ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  b_gate,
const float *  b_up,
float *  output,
int  M,
int  N,
int  K 
)

Definition at line 241 of file gemm_fused_kernels.c.

248 {
249 #if defined(__AVX__)
250 #pragma omp parallel for
251  for (int i = 0; i < M; i++) {
252  const float *x_row = &x[i * K];
253  float *out_row = &output[i * N];
254 
255  for (int j = 0; j < N; j++) {
256  const float *w_gate_row = &W_gate[j * K];
257  const float *w_up_row = &W_up[j * K];
258 
259  // Compute both dot products in parallel using SIMD
260  __m256 gate_vec = _mm256_setzero_ps();
261  __m256 up_vec = _mm256_setzero_ps();
262 
263  int k;
264  for (k = 0; k <= K - 8; k += 8) {
265  __m256 x_vec = _mm256_loadu_ps(&x_row[k]);
266  __m256 wg_vec = _mm256_loadu_ps(&w_gate_row[k]);
267  __m256 wu_vec = _mm256_loadu_ps(&w_up_row[k]);
268 
269  // gate += x * W_gate
270  gate_vec = _mm256_add_ps(gate_vec, _mm256_mul_ps(x_vec, wg_vec));
271  // up += x * W_up
272  up_vec = _mm256_add_ps(up_vec, _mm256_mul_ps(x_vec, wu_vec));
273  }
274 
275  // Horizontal sum
276  float gate = hsum256_ps_fused(gate_vec);
277  float up = hsum256_ps_fused(up_vec);
278 
279  // Scalar remainder
280  for (; k < K; k++) {
281  gate += x_row[k] * w_gate_row[k];
282  up += x_row[k] * w_up_row[k];
283  }
284 
285  // Add biases
286  if (b_gate) gate += b_gate[j];
287  if (b_up) up += b_up[j];
288 
289  // SwiGLU: SiLU(gate) * up = gate * sigmoid(gate) * up
290  float sig = 1.0f / (1.0f + expf(-gate));
291  out_row[j] = gate * sig * up;
292  }
293  }
294 #else
295  // Scalar fallback
296 #pragma omp parallel for
297  for (int i = 0; i < M; i++) {
298  for (int j = 0; j < N; j++) {
299  float gate = 0.0f;
300  float up = 0.0f;
301 
302  for (int k = 0; k < K; k++) {
303  gate += x[i * K + k] * W_gate[j * K + k];
304  up += x[i * K + k] * W_up[j * K + k];
305  }
306 
307  if (b_gate) gate += b_gate[j];
308  if (b_up) up += b_up[j];
309 
310  // SwiGLU: SiLU(gate) * up
311  float sig = 1.0f / (1.0f + expf(-gate));
312  output[i * N + j] = gate * sig * up;
313  }
314  }
315 #endif
316 }

References hsum256_ps_fused().

Referenced by ck_mlp_swiglu_forward_fused_token().