← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_fused_kernels.c
Go to the documentation of this file.
1 /**
2  * @file gemm_fused_kernels.c
3  * @brief Fused GEMM Kernels with activations
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * GEMM operations fused with activations (ReLU, GELU, SiLU) and SwiGLU.
15  * The key benefit: intermediate results stay in registers, avoiding DRAM
16  * round-trips between operations.
17  *
18  * Supported operations:
19  * - gemm_bias_relu_fused: C = ReLU(A @ B^T + bias)
20  * - gemm_bias_gelu_fused: C = GELU(A @ B^T + bias)
21  * - gemm_bias_silu_fused: C = SiLU(A @ B^T + bias)
22  * - gemm_swiglu_fused: C = SiLU(x @ W_gate) * (x @ W_up)
23  *
24  * All kernels support:
25  * - AVX1 SIMD (256-bit vectors, no FMA)
 * - OpenMP parallelization via #pragma omp (NOTE(review): this contradicts
 *   kernel rule 2 above, "NO OpenMP — parallelization at orchestrator/codegen
 *   layer"; the pragmas are no-ops unless built with -fopenmp. Confirm which
 *   policy is current.)
27  * - Scalar fallback
28  */
29 
30 #include "ckernel_engine.h"
31 #include <math.h>
32 
33 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
34 #include <immintrin.h>
35 #endif
36 
37 #ifdef _OPENMP
38 #include <omp.h>
39 #endif
40 
41 // =============================================================================
42 // SIMD Helpers
43 // =============================================================================
44 
#if defined(__AVX__) || defined(__AVX512F__)
// Horizontal sum of the 8 floats in a __m256 (shared by AVX and AVX-512 paths).
//
// Reduction scheme: split into the two 128-bit halves and add them, then fold
// the 4-lane partial sums twice (movehdup pairs adjacent lanes, movehl brings
// the upper pair down). Uses only SSE3/AVX1 instructions — no FMA, no AVX2 —
// per the file-wide "AVX1, no FMA" target noted in the header.
static inline float hsum256_ps_fused(__m256 v) {
    __m128 lo = _mm256_castps256_ps128(v);   // lanes 0..3 (zero-cost cast)
    __m128 hi = _mm256_extractf128_ps(v, 1); // lanes 4..7
    __m128 sum128 = _mm_add_ps(lo, hi);      // 4 partial sums
    __m128 shuf = _mm_movehdup_ps(sum128);   // duplicate odd lanes onto even
    __m128 sums = _mm_add_ps(sum128, shuf);  // lane0 = l0+l1, lane2 = l2+l3
    shuf = _mm_movehl_ps(shuf, sums);        // move lane2 down to lane0
    sums = _mm_add_ss(sums, shuf);           // lane0 = total
    return _mm_cvtss_f32(sums);
}
#endif
58 
#if defined(__AVX512F__)
// Horizontal sum of the 16 floats in a __m512: fold to 256 bits, then reuse
// the 256-bit reduction above.
//
// FIX: the previous implementation used _mm512_extractf32x8_ps, which is an
// AVX512DQ instruction — it does not compile on AVX512F-only targets even
// though this block is guarded by __AVX512F__ alone. Extracting the upper
// 256 bits via _mm512_extractf64x4_pd (plain AVX512F) and bit-casting back
// to floats is bit-identical and compiles everywhere __AVX512F__ is set.
//
// NOTE(review): not referenced by any kernel in this file — presumably kept
// for AVX-512 kernel variants generated elsewhere; confirm before removing.
static inline float hsum512_ps_fused(__m512 v) {
    __m256 lo = _mm512_castps512_ps256(v); // lanes 0..7 (zero-cost cast)
    __m256 hi = _mm256_castpd_ps(
        _mm512_extractf64x4_pd(_mm512_castps_pd(v), 1)); // lanes 8..15
    __m256 sum256 = _mm256_add_ps(lo, hi);
    return hsum256_ps_fused(sum256);
}
#endif
68 
69 // =============================================================================
70 // Fast activation approximations (scalar)
71 // =============================================================================
72 
73 // GELU approximation: x * sigmoid(1.702 * x) (QuickGELU)
74 static inline float fast_gelu_scalar(float x) {
75  float sx = 1.702f * x;
76  float sig = 1.0f / (1.0f + expf(-sx));
77  return x * sig;
78 }
79 
// =============================================================================
// GEMM + Bias + ReLU fused
// C[i,j] = max(0, sum_k(A[i,k] * B[j,k]) + bias[j])
// =============================================================================

/**
 * @brief C = ReLU(A @ B^T + bias), fused so the accumulator never hits DRAM
 *        between the dot product, bias add, and activation.
 *
 * @param A    [M, K] row-major activations
 * @param B    [N, K] row-major weights (transposed layout: row j of B is the
 *             j-th output column's weight vector)
 * @param bias [N] per-output bias, or NULL for no bias (NULL accepted for
 *             consistency with gemm_swiglu_fused; previously dereferenced
 *             unconditionally)
 * @param C    [M, N] row-major output
 * @param M,N,K matrix dimensions
 */
void gemm_bias_relu_fused(const float *A,
                          const float *B,
                          const float *bias,
                          float *C,
                          int M, int N, int K)
{
#if defined(__AVX__)
// NOTE(review): file header rule 2 says "NO OpenMP"; this pragma is a no-op
// unless compiled with -fopenmp — confirm intended parallelization layer.
#pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            __m256 sum_vec = _mm256_setzero_ps();
            int k;
            // Vector body: separate mul + add (AVX1 target, no FMA).
            for (k = 0; k <= K - 8; k += 8) {
                __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
                __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
                __m256 prod = _mm256_mul_ps(a_vec, b_vec);
                sum_vec = _mm256_add_ps(sum_vec, prod);
            }
            float sum = hsum256_ps_fused(sum_vec);
            // Scalar tail for K not divisible by 8.
            for (; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            // Fused: bias add and ReLU while the value is still in a register.
            if (bias) sum += bias[j];
            C[i * N + j] = sum > 0.0f ? sum : 0.0f;
        }
    }
#else
    // Scalar fallback.
#pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            if (bias) sum += bias[j];
            C[i * N + j] = sum > 0.0f ? sum : 0.0f;
        }
    }
#endif
}
125 
// =============================================================================
// GEMM + Bias + GELU fused
// C[i,j] = GELU(sum_k(A[i,k] * B[j,k]) + bias[j])
// Uses QuickGELU approximation: x * sigmoid(1.702 * x)
// =============================================================================

/**
 * @brief C = QuickGELU(A @ B^T + bias), fused in one pass.
 *
 * @param A    [M, K] row-major activations
 * @param B    [N, K] row-major weights (row j of B is output column j's
 *             weight vector)
 * @param bias [N] per-output bias, or NULL for no bias (NULL accepted for
 *             consistency with gemm_swiglu_fused; previously dereferenced
 *             unconditionally)
 * @param C    [M, N] row-major output
 * @param M,N,K matrix dimensions
 */
void gemm_bias_gelu_fused(const float *A,
                          const float *B,
                          const float *bias,
                          float *C,
                          int M, int N, int K)
{
#if defined(__AVX__)
#pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            __m256 sum_vec = _mm256_setzero_ps();
            int k;
            // Vector body: separate mul + add (AVX1 target, no FMA).
            for (k = 0; k <= K - 8; k += 8) {
                __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
                __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
                __m256 prod = _mm256_mul_ps(a_vec, b_vec);
                sum_vec = _mm256_add_ps(sum_vec, prod);
            }
            float sum = hsum256_ps_fused(sum_vec);
            // Scalar tail for K not divisible by 8.
            for (; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            if (bias) sum += bias[j];
            C[i * N + j] = fast_gelu_scalar(sum);
        }
    }
#else
    // Scalar fallback.
#pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            if (bias) sum += bias[j];
            C[i * N + j] = fast_gelu_scalar(sum);
        }
    }
#endif
}
171 
172 // =============================================================================
173 // GEMM + Bias + SiLU/Swish fused
174 // C[i,j] = SiLU(sum_k(A[i,k] * B[j,k]) + bias[j])
175 // SiLU(x) = x * sigmoid(x)
176 // =============================================================================
177 void gemm_bias_silu_fused(const float *A,
178  const float *B,
179  const float *bias,
180  float *C,
181  int M, int N, int K)
182 {
183 #if defined(__AVX__)
184 #pragma omp parallel for
185  for (int i = 0; i < M; i++) {
186  for (int j = 0; j < N; j++) {
187  __m256 sum_vec = _mm256_setzero_ps();
188  int k;
189  for (k = 0; k <= K - 8; k += 8) {
190  __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
191  __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
192  __m256 prod = _mm256_mul_ps(a_vec, b_vec);
193  sum_vec = _mm256_add_ps(sum_vec, prod);
194  }
195  float sum = hsum256_ps_fused(sum_vec);
196  for (; k < K; k++) {
197  sum += A[i * K + k] * B[j * K + k];
198  }
199  sum += bias[j];
200  // SiLU: x * sigmoid(x)
201  float sig = 1.0f / (1.0f + expf(-sum));
202  C[i * N + j] = sum * sig;
203  }
204  }
205 #else
206 #pragma omp parallel for
207  for (int i = 0; i < M; i++) {
208  for (int j = 0; j < N; j++) {
209  float sum = 0.0f;
210  for (int k = 0; k < K; k++) {
211  sum += A[i * K + k] * B[j * K + k];
212  }
213  sum += bias[j];
214  float sig = 1.0f / (1.0f + expf(-sum));
215  C[i * N + j] = sum * sig;
216  }
217  }
218 #endif
219 }
220 
221 // =============================================================================
222 // GEMM + SwiGLU Fused (LLaMA/SmolLM style MLP)
223 //
224 // Computes: output = SiLU(x @ W_gate + b_gate) * (x @ W_up + b_up)
225 //
226 // This fuses TWO GEMMs + SwiGLU activation into one pass:
227 // - gate = x @ W_gate + b_gate (GEMM 1)
228 // - up = x @ W_up + b_up (GEMM 2)
229 // - out = SiLU(gate) * up (SwiGLU)
230 //
231 // Layout:
232 // x: [M, K] input activations
233 // W_gate: [N, K] gate projection weights (transposed)
234 // W_up: [N, K] up projection weights (transposed)
235 // b_gate: [N] gate bias (can be NULL)
236 // b_up: [N] up bias (can be NULL)
237 // output: [M, N] result
238 //
239 // The key insight: gate and up values stay in registers, never written to DRAM
240 // =============================================================================
241 void gemm_swiglu_fused(const float *x,
242  const float *W_gate,
243  const float *W_up,
244  const float *b_gate,
245  const float *b_up,
246  float *output,
247  int M, int N, int K)
248 {
249 #if defined(__AVX__)
250 #pragma omp parallel for
251  for (int i = 0; i < M; i++) {
252  const float *x_row = &x[i * K];
253  float *out_row = &output[i * N];
254 
255  for (int j = 0; j < N; j++) {
256  const float *w_gate_row = &W_gate[j * K];
257  const float *w_up_row = &W_up[j * K];
258 
259  // Compute both dot products in parallel using SIMD
260  __m256 gate_vec = _mm256_setzero_ps();
261  __m256 up_vec = _mm256_setzero_ps();
262 
263  int k;
264  for (k = 0; k <= K - 8; k += 8) {
265  __m256 x_vec = _mm256_loadu_ps(&x_row[k]);
266  __m256 wg_vec = _mm256_loadu_ps(&w_gate_row[k]);
267  __m256 wu_vec = _mm256_loadu_ps(&w_up_row[k]);
268 
269  // gate += x * W_gate
270  gate_vec = _mm256_add_ps(gate_vec, _mm256_mul_ps(x_vec, wg_vec));
271  // up += x * W_up
272  up_vec = _mm256_add_ps(up_vec, _mm256_mul_ps(x_vec, wu_vec));
273  }
274 
275  // Horizontal sum
276  float gate = hsum256_ps_fused(gate_vec);
277  float up = hsum256_ps_fused(up_vec);
278 
279  // Scalar remainder
280  for (; k < K; k++) {
281  gate += x_row[k] * w_gate_row[k];
282  up += x_row[k] * w_up_row[k];
283  }
284 
285  // Add biases
286  if (b_gate) gate += b_gate[j];
287  if (b_up) up += b_up[j];
288 
289  // SwiGLU: SiLU(gate) * up = gate * sigmoid(gate) * up
290  float sig = 1.0f / (1.0f + expf(-gate));
291  out_row[j] = gate * sig * up;
292  }
293  }
294 #else
295  // Scalar fallback
296 #pragma omp parallel for
297  for (int i = 0; i < M; i++) {
298  for (int j = 0; j < N; j++) {
299  float gate = 0.0f;
300  float up = 0.0f;
301 
302  for (int k = 0; k < K; k++) {
303  gate += x[i * K + k] * W_gate[j * K + k];
304  up += x[i * K + k] * W_up[j * K + k];
305  }
306 
307  if (b_gate) gate += b_gate[j];
308  if (b_up) up += b_up[j];
309 
310  // SwiGLU: SiLU(gate) * up
311  float sig = 1.0f / (1.0f + expf(-gate));
312  output[i * N + j] = gate * sig * up;
313  }
314  }
315 #endif
316 }
void gemm_bias_silu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
void gemm_swiglu_fused(const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K)
void gemm_bias_relu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
static float fast_gelu_scalar(float x)
void gemm_bias_gelu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
static float hsum256_ps_fused(__m256 v)
#define C(color)
Definition: show_config.c:39