← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mlp_kernels.c
Go to the documentation of this file.
1 /**
2  * @file mlp_kernels.c
3  * @brief MLP (feed-forward) kernels with SIMD (SSE/AVX/AVX512)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * LEGACY EXCEPTION: This file contains OpenMP for backward compatibility.
15  * New kernels should NOT use OpenMP internally.
16  *
17  * MLP: out = FC2(GELU(FC1(x)))
18  */
19 
20 #include "ckernel_engine.h"
21 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
22 #include <immintrin.h>
23 #endif
24 #include <omp.h>
25 #include <stdlib.h>
26 
27 /* Forward MLP kernel (FC1 -> GELU -> FC2) adapted from C-Transformer's */
28 // mlp_token_parallel but expressed in a model-agnostic form. We keep the
29 // familiar name `mlp_token_parallel` for reuse during decode/inference.
30 //
31 // Shapes:
32 // input: [T × D] (row-major, stride = aligned_dim)
33 // W_fc1: [4D × D] (row-major, stored as [out × in])
34 // b_fc1: [4D]
35 // W_fc2: [D × 4D]
36 // b_fc2: [D]
37 // fc1_output: [T × 4D] (workspace, also becomes GELU input/output)
38 // output: [T × D]
39 //
40 // D is typically `aligned_dim` in your transformer; pass that value here.
void mlp_token_parallel(const float *input,
                        const float *W_fc1,
                        const float *b_fc1,
                        const float *W_fc2,
                        const float *b_fc2,
                        float *fc1_output,
                        float *output,
                        int T,
                        int aligned_dim,
                        int num_threads)
{
    /* FIX: num_threads was silently unused, causing -Wunused-parameter and
     * diverging from mlp_token_parallel_exact / the backward kernels, which
     * all explicitly void it. Parallelization happens at the orchestrator
     * layer per the kernel rules; the parameter is kept for ABI stability. */
    (void)num_threads;

    int D = aligned_dim;
    int fourD = 4 * D;

    /* FC1: [T × D] · [D × 4D] -> [T × 4D]
     * Our GEMM layout: A[M×K], B[N×K], so B (W_fc1) is stored [4D × D]. */
    gemm_blocked_serial(input, W_fc1, b_fc1,
                        fc1_output,
                        T,      // M
                        fourD,  // N
                        D);     // K

    /* Fast (approximate) GELU applied in-place on the FC1 activations. */
    gelu_fast_inplace(fc1_output, (size_t)T * (size_t)fourD);

    /* FC2: [T × 4D] · [4D × D] -> [T × D] */
    gemm_blocked_serial(fc1_output, W_fc2, b_fc2,
                        output,
                        T,       // M
                        D,       // N
                        fourD);  // K
}
73 
74 // Exact version of MLP forward using scalar GELU with standard library tanhf.
75 // Slower but provides maximum accuracy. Used for correctness testing.
/* Reference-accuracy MLP forward pass: FC1 -> exact GELU (libm tanhf) -> FC2.
 * Slower than mlp_token_parallel; intended for correctness testing.
 * Shapes match mlp_token_parallel: input [T × D], W_fc1 [4D × D],
 * W_fc2 [D × 4D], fc1_output [T × 4D] workspace, output [T × D]. */
void mlp_token_parallel_exact(const float *input,
                              const float *W_fc1,
                              const float *b_fc1,
                              const float *W_fc2,
                              const float *b_fc2,
                              float *fc1_output,
                              float *output,
                              int T,
                              int aligned_dim,
                              int num_threads)
{
    (void)num_threads; /* parallelism lives at the orchestrator layer */

    const int dim = aligned_dim;
    const int hidden = dim * 4;
    const size_t hidden_elems = (size_t)T * (size_t)hidden;

    /* First projection into the 4x-wider hidden space:
     * [T × D] · [D × 4D] -> [T × 4D] (GEMM layout A[M×K], B[N×K]). */
    gemm_blocked_serial(input, W_fc1, b_fc1, fc1_output,
                        /* M */ T,
                        /* N */ hidden,
                        /* K */ dim);

    /* Exact GELU via standard-library tanhf, applied in place. */
    gelu_exact_inplace(fc1_output, hidden_elems);

    /* Project back down: [T × 4D] · [4D × D] -> [T × D]. */
    gemm_blocked_serial(fc1_output, W_fc2, b_fc2, output,
                        /* M */ T,
                        /* N */ dim,
                        /* K */ hidden);
}
108 
109 // Generic FC2 backward kernel adapted from C-Transformer's backward_fc2_feature_parallel.
110 // Now uses shared GEMM kernels for d_input and d_W computation.
111 // Shapes:
112 // d_output: [T × aligned_out]
113 // fc2_input: [T × aligned_in]
114 // W_fc2: [aligned_out × aligned_in] (row-major)
115 // d_input: [T × aligned_in]
116 // d_W_fc2: [aligned_out × aligned_in] (accumulated)
117 // d_b_fc2: [aligned_out] (accumulated)
/* Backward pass for the FC2 (down-projection) layer.
 * Computes, via the shared AVX-512 GEMM kernels:
 *   d_input[T, in]  = d_output[T, out] @ W_fc2[out, in]        (overwritten)
 *   d_W_fc2[out,in] = d_output[T, out]^T @ fc2_input[T, in]    (overwritten;
 *                     gradient accumulation is handled at a higher level)
 *   d_b_fc2[out]   += column sums of d_output over T           (accumulated)
 * The bias reduction keeps OpenMP for legacy compatibility; new kernels
 * should not parallelize internally. */
void fc2_backward_kernel(const float *d_output,
                         const float *fc2_input,
                         const float *W_fc2,
                         float *d_input,
                         float *d_W_fc2,
                         float *d_b_fc2,
                         int T,
                         int aligned_in,
                         int aligned_out,
                         int num_threads)
{
    (void)num_threads; /* GEMM kernels manage their own threading */

    /* d_input = d_output @ W: gemm_nn computes C[M,N] = A[M,K] @ B[K,N]
     * with M = T, N = aligned_in, K = aligned_out. */
    gemm_nn_avx512(d_output, W_fc2, NULL, d_input,
                   T, aligned_in, aligned_out);

    /* d_W = d_output^T @ fc2_input: gemm_tn computes C[M,N] = A[K,M]^T @ B[K,N]
     * with M = aligned_out, N = aligned_in, K = T. Note gemm_tn overwrites C,
     * so callers that accumulate must pre-zero (or sum) at a higher level. */
    gemm_tn_avx512(d_output, fc2_input, NULL, d_W_fc2,
                   aligned_out, aligned_in, T);

    /* Bias gradient: reduce d_output over the token axis.
     * Inner loop stays over t (same order as before) so the float
     * summation order — and hence the result — is unchanged. */
#pragma omp parallel for schedule(static)
    for (int o = 0; o < aligned_out; ++o) {
        float acc = 0.0f;
        for (int t = 0; t < T; ++t) {
            acc += d_output[(size_t)t * aligned_out + o];
        }
        d_b_fc2[o] += acc;
    }
}
157 
158 // Generic FC1 backward kernel adapted from C-Transformer's backward_fc1_feature_parallel.
159 // Now uses shared GEMM kernels for d_input and d_W computation.
160 // Shapes:
161 // d_output: [T × aligned_out]
162 // fc1_input: [T × aligned_in]
163 // W_fc1: [aligned_out × aligned_in] (row-major)
164 // d_input: [T × aligned_in]
165 // d_W_fc1: [aligned_out × aligned_in] (accumulated)
166 // d_b_fc1: [aligned_out] (accumulated)
/* Backward pass for the FC1 (up-projection) layer.
 * Mirrors fc2_backward_kernel with FC1's tensors:
 *   d_input[T, in]  = d_output[T, out] @ W_fc1[out, in]        (overwritten)
 *   d_W_fc1[out,in] = d_output[T, out]^T @ fc1_input[T, in]    (overwritten)
 *   d_b_fc1[out]   += column sums of d_output over T           (accumulated)
 * OpenMP in the bias reduction is a legacy exception per the file header. */
void fc1_backward_kernel(const float *d_output,
                         const float *fc1_input,
                         const float *W_fc1,
                         float *d_input,
                         float *d_W_fc1,
                         float *d_b_fc1,
                         int T,
                         int aligned_in,
                         int aligned_out,
                         int num_threads)
{
    (void)num_threads; /* threading is owned by the GEMM kernels */

    const int in_dim = aligned_in;
    const int out_dim = aligned_out;

    /* Gradient w.r.t. the layer input:
     * C[M,N] = A[M,K] @ B[K,N] with M = T, N = in_dim, K = out_dim. */
    gemm_nn_avx512(d_output, W_fc1, NULL, d_input, T, in_dim, out_dim);

    /* Gradient w.r.t. the weights:
     * C[M,N] = A[K,M]^T @ B[K,N] with M = out_dim, N = in_dim, K = T.
     * gemm_tn overwrites C; accumulation is the caller's responsibility. */
    gemm_tn_avx512(d_output, fc1_input, NULL, d_W_fc1, out_dim, in_dim, T);

    /* Gradient w.r.t. the bias: per-feature sum of d_output across tokens.
     * Loop nesting (t innermost) is preserved to keep float summation
     * order — and therefore bitwise results — identical. */
#pragma omp parallel for schedule(static)
    for (int feature = 0; feature < out_dim; ++feature) {
        float grad_sum = 0.0f;
        for (int t = 0; t < T; ++t) {
            grad_sum += d_output[(size_t)t * out_dim + feature];
        }
        d_b_fc1[feature] += grad_sum;
    }
}
void gelu_exact_inplace(float *data, size_t n)
Definition: gelu_kernels.c:446
void gemm_nn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:339
void gelu_fast_inplace(float *data, size_t n)
Definition: gelu_kernels.c:132
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:661
void gemm_tn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:521
void mlp_token_parallel(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
Definition: mlp_kernels.c:41
void fc1_backward_kernel(const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:167
void mlp_token_parallel_exact(const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
Definition: mlp_kernels.c:76
void fc2_backward_kernel(const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
Definition: mlp_kernels.c:118