← Back to C-Kernel-Engine Docs Doxygen Source Documentation
mlp_kernels.c File Reference

MLP (feed-forward) kernels with SIMD (SSE/AVX/AVX512) More...

#include "ckernel_engine.h"
#include <omp.h>
#include <stdlib.h>

Go to the source code of this file.

Functions

void fc1_backward_kernel (const float *d_output, const float *fc1_input, const float *W_fc1, float *d_input, float *d_W_fc1, float *d_b_fc1, int T, int aligned_in, int aligned_out, int num_threads)
 
void fc2_backward_kernel (const float *d_output, const float *fc2_input, const float *W_fc2, float *d_input, float *d_W_fc2, float *d_b_fc2, int T, int aligned_in, int aligned_out, int num_threads)
 
void mlp_token_parallel (const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
 
void mlp_token_parallel_exact (const float *input, const float *W_fc1, const float *b_fc1, const float *W_fc2, const float *b_fc2, float *fc1_output, float *output, int T, int aligned_dim, int num_threads)
 

Detailed Description

MLP (feed-forward) kernels with SIMD (SSE/AVX/AVX512)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

LEGACY EXCEPTION: This file retains OpenMP for backward compatibility (the bias-gradient reductions in fc1_backward_kernel / fc2_backward_kernel use `#pragma omp parallel for`). New kernels should NOT use OpenMP internally.

MLP: out = FC2(GELU(FC1(x)))

Definition in file mlp_kernels.c.

Function Documentation

◆ fc1_backward_kernel()

/**
 * Backward pass through the FC1 (up-projection) layer.
 *
 * Given the upstream gradient d_output [T x aligned_out], computes:
 *   d_input [T x aligned_in]           = d_output @ W_fc1           (overwritten)
 *   d_W_fc1 [aligned_out x aligned_in] = d_output^T @ fc1_input     (overwritten)
 *   d_b_fc1 [aligned_out]             += per-column sums of d_output (accumulated)
 *
 * NOTE(review): the GEMM outputs overwrite d_input/d_W_fc1 while the bias
 * gradient accumulates with += — these agree only if gradients start zeroed;
 * confirm the caller's zeroing convention before relying on accumulation.
 *
 * num_threads is unused here: threading lives inside the GEMM kernels and in
 * the OpenMP bias reduction below (legacy exception, see file header).
 */
void fc1_backward_kernel(const float *d_output, const float *fc1_input,
                         const float *W_fc1, float *d_input, float *d_W_fc1,
                         float *d_b_fc1, int T, int aligned_in,
                         int aligned_out, int num_threads)
{
    (void)num_threads; /* threading handled by the GEMM kernels */

    /* 1. d_input[T, in] = d_output[T, out] @ W[out, in]
     *    gemm_nn: C[M,N] = A[M,K] @ B[K,N]  with M=T, N=aligned_in, K=aligned_out */
    gemm_nn_avx512(d_output, W_fc1, NULL, d_input, T, aligned_in, aligned_out);

    /* 2. d_W[out, in] = d_output[T, out]^T @ fc1_input[T, in]
     *    gemm_tn: C[M,N] = A[K,M]^T @ B[K,N] with M=aligned_out, N=aligned_in, K=T */
    gemm_tn_avx512(d_output, fc1_input, NULL, d_W_fc1,
                   aligned_out, aligned_in, T);

    /* 3. d_b_fc1 += sum over tokens of d_output (one column sum per unit) */
#pragma omp parallel for schedule(static)
    for (int o = 0; o < aligned_out; ++o) {
        float acc = 0.0f;
        const float *col = d_output + o; /* column o, stride aligned_out */
        for (int t = 0; t < T; ++t) {
            acc += col[(size_t)t * aligned_out];
        }
        d_b_fc1[o] += acc;
    }
}
void gemm_nn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:339
void gemm_tn_avx512(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:521

References gemm_nn_avx512(), and gemm_tn_avx512().

Referenced by ck_layer_backward_rmsnorm_swiglu().

◆ fc2_backward_kernel()

/**
 * Backward pass through the FC2 (down-projection) layer.
 *
 * Given the upstream gradient d_output [T x aligned_out], computes:
 *   d_input [T x aligned_in]           = d_output @ W_fc2           (overwritten)
 *   d_W_fc2 [aligned_out x aligned_in] = d_output^T @ fc2_input     (overwritten)
 *   d_b_fc2 [aligned_out]             += per-column sums of d_output (accumulated)
 *
 * gemm_tn overwrites its output, so d_W_fc2 is assumed to start zeroed;
 * gradient accumulation across steps is handled at a higher level.
 *
 * num_threads is unused here: threading lives inside the GEMM kernels and in
 * the OpenMP bias reduction below (legacy exception, see file header).
 */
void fc2_backward_kernel(const float *d_output, const float *fc2_input,
                         const float *W_fc2, float *d_input, float *d_W_fc2,
                         float *d_b_fc2, int T, int aligned_in,
                         int aligned_out, int num_threads)
{
    (void)num_threads; /* threading handled by the GEMM kernels */

    /* 1. d_input[T, in] = d_output[T, out] @ W[out, in]
     *    gemm_nn: C[M,N] = A[M,K] @ B[K,N]  with M=T, N=aligned_in, K=aligned_out */
    gemm_nn_avx512(d_output, W_fc2, NULL, d_input, T, aligned_in, aligned_out);

    /* 2. d_W[out, in] = d_output[T, out]^T @ fc2_input[T, in]
     *    gemm_tn: C[M,N] = A[K,M]^T @ B[K,N] with M=aligned_out, N=aligned_in, K=T */
    gemm_tn_avx512(d_output, fc2_input, NULL, d_W_fc2,
                   aligned_out, aligned_in, T);

    /* 3. d_b_fc2 += sum over tokens of d_output (one column sum per unit) */
#pragma omp parallel for schedule(static)
    for (int o = 0; o < aligned_out; ++o) {
        float acc = 0.0f;
        const float *col = d_output + o; /* column o, stride aligned_out */
        for (int t = 0; t < T; ++t) {
            acc += col[(size_t)t * aligned_out];
        }
        d_b_fc2[o] += acc;
    }
}

References gemm_nn_avx512(), and gemm_tn_avx512().

Referenced by ck_attention_project_head_major_backward(), ck_layer_backward_rmsnorm_swiglu(), and ck_qkv_project_head_major_backward().

◆ mlp_token_parallel()

/**
 * Forward MLP block: output = FC2(GELU_fast(FC1(input))).
 *
 * Shapes (D = aligned_dim, hidden = 4*D):
 *   input      [T x D]       tokens
 *   W_fc1      [4D x D]      up-projection weights (GEMM expects B as [N x K])
 *   b_fc1      [4D]          up-projection bias (fused into GEMM)
 *   W_fc2      [D x 4D]      down-projection weights
 *   b_fc2      [D]           down-projection bias (fused into GEMM)
 *   fc1_output [T x 4D]      workspace; holds GELU(FC1(x)) on return
 *   output     [T x D]       final result
 *
 * Fix: num_threads was accepted but neither used nor void-cast (unlike the
 * sibling mlp_token_parallel_exact), triggering an unused-parameter warning;
 * the GEMMs here are serial, so the parameter is explicitly discarded.
 */
void mlp_token_parallel(const float *input, const float *W_fc1,
                        const float *b_fc1, const float *W_fc2,
                        const float *b_fc2, float *fc1_output, float *output,
                        int T, int aligned_dim, int num_threads)
{
    (void)num_threads; /* serial GEMMs; matches mlp_token_parallel_exact */

    int D = aligned_dim;
    int fourD = 4 * D;

    /* FC1: [T x D] -> [T x 4D], bias added inside the GEMM */
    gemm_blocked_serial(input, W_fc1, b_fc1,
                        fc1_output,
                        T,      /* M */
                        fourD,  /* N */
                        D);     /* K */

    /* Fast (approximate) GELU applied elementwise, in place */
    gelu_fast_inplace(fc1_output, (size_t)T * (size_t)fourD);

    /* FC2: [T x 4D] -> [T x D], bias added inside the GEMM */
    gemm_blocked_serial(fc1_output, W_fc2, b_fc2,
                        output,
                        T,       /* M */
                        D,       /* N */
                        fourD);  /* K */
}
void gelu_fast_inplace(float *data, size_t n)
Definition: gelu_kernels.c:132
void gemm_blocked_serial(const float *A, const float *B, const float *bias, float *C, int M, int N, int K)
Definition: gemm_kernels.c:661

References gelu_fast_inplace(), and gemm_blocked_serial().

◆ mlp_token_parallel_exact()

/**
 * Forward MLP block using the exact (tanhf-based) GELU:
 *   output = FC2(GELU_exact(FC1(input)))
 *
 * Shapes (D = aligned_dim, hidden = 4*D):
 *   input      [T x D]       tokens
 *   W_fc1      [4D x D]      up-projection weights (GEMM expects B as [N x K])
 *   b_fc1      [4D]          up-projection bias (fused into GEMM)
 *   W_fc2      [D x 4D]      down-projection weights
 *   b_fc2      [D]           down-projection bias (fused into GEMM)
 *   fc1_output [T x 4D]      workspace; holds GELU(FC1(x)) on return
 *   output     [T x D]       final result
 *
 * num_threads is accepted for interface parity but unused (serial GEMMs).
 */
void mlp_token_parallel_exact(const float *input, const float *W_fc1,
                              const float *b_fc1, const float *W_fc2,
                              const float *b_fc2, float *fc1_output,
                              float *output, int T, int aligned_dim,
                              int num_threads)
{
    (void)num_threads;

    const int d_model = aligned_dim;
    const int d_hidden = 4 * d_model;

    /* Up-projection: [T x D] -> [T x 4D], bias fused into the GEMM */
    gemm_blocked_serial(input, W_fc1, b_fc1, fc1_output,
                        T, d_hidden, d_model);

    /* Exact GELU via standard-library tanhf, elementwise in place */
    gelu_exact_inplace(fc1_output, (size_t)T * (size_t)d_hidden);

    /* Down-projection: [T x 4D] -> [T x D], bias fused into the GEMM */
    gemm_blocked_serial(fc1_output, W_fc2, b_fc2, output,
                        T, d_model, d_hidden);
}
void gelu_exact_inplace(float *data, size_t n)
Definition: gelu_kernels.c:446

References gelu_exact_inplace(), and gemm_blocked_serial().