← Back to C-Kernel-Engine Docs Doxygen Source Documentation
swiglu_kernels_bf16.c
Go to the documentation of this file.
1 /**
2  * @file swiglu_kernels_bf16.c
3  * @brief SwiGLU activation kernels for BF16 tensors
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * SwiGLU: y = silu(gate) * up = (gate * sigmoid(gate)) * up
15  */
16 
17 #include <stdint.h>
18 #include <math.h>
19 
20 #include "bf16_utils.h"
21 #include "ckernel_engine.h"
22 
#if defined(__AVX512F__)
#include <immintrin.h>

/*
 * exp512_fast_bf16 - approximate e^x on 16 FP32 lanes.
 *
 * Range reduction: e^x = 2^t with t = x * log2(e). Split t = n + r, where
 * n = round-to-nearest(t), so that r lies in [-0.5, 0.5].
 *   - 2^r is evaluated with a degree-4 polynomial in Horner form (FMA). The
 *     coefficients are the Taylor terms ln2^k / k! for k = 0..4.
 *   - 2^n is assembled directly in the IEEE-754 exponent field:
 *     bits = (n + 127) << 23.
 *
 * x is clamped to [-88, 88] first, which keeps n within [-127, 127] and the
 * biased exponent (n + 127) within [0, 254] (never the Inf/NaN encoding).
 * NOTE(review): when n == -127 (roughly x < -87.7) the biased exponent is 0,
 * so 2^n is rebuilt as +0.0 and the result is 0 rather than a tiny positive
 * value. This is harmless for the sigmoid below.
 */
static inline __m512 exp512_fast_bf16(__m512 x) {
    /* Clamp: max with the lower bound first, then min with the upper bound. */
    const __m512 clamped = _mm512_min_ps(_mm512_max_ps(x, _mm512_set1_ps(-88.0f)),
                                         _mm512_set1_ps(88.0f));

    /* t = x * log2(e); n = nearest integer to t; r = t - n in [-0.5, 0.5]. */
    const __m512 t = _mm512_mul_ps(clamped, _mm512_set1_ps(1.4426950408889634f));
    const __m512 n = _mm512_roundscale_ps(t, _MM_FROUND_TO_NEAREST_INT);
    const __m512 r = _mm512_sub_ps(t, n);

    /* 2^n: bias the integer exponent and shift it into bits 23..30. */
    __m512i exp_bits = _mm512_cvtps_epi32(n);
    exp_bits = _mm512_slli_epi32(_mm512_add_epi32(exp_bits, _mm512_set1_epi32(127)), 23);
    const __m512 two_pow_n = _mm512_castsi512_ps(exp_bits);

    /* 2^r ~= 1 + r*(ln2 + r*(ln2^2/2 + r*(ln2^3/6 + r*ln2^4/24))) */
    __m512 p = _mm512_fmadd_ps(r, _mm512_set1_ps(0.009618129107628478f),
                               _mm512_set1_ps(0.05550410866482158f));
    p = _mm512_fmadd_ps(r, p, _mm512_set1_ps(0.2402265069591007f));
    p = _mm512_fmadd_ps(r, p, _mm512_set1_ps(0.6931471805599453f));
    p = _mm512_fmadd_ps(r, p, _mm512_set1_ps(1.0f));

    return _mm512_mul_ps(p, two_pow_n);
}

/*
 * sigmoid512_fast_bf16 - logistic function 1 / (1 + e^-x) on 16 FP32 lanes,
 * computed with the approximate exponential exp512_fast_bf16.
 */
static inline __m512 sigmoid512_fast_bf16(__m512 x) {
    const __m512 one = _mm512_set1_ps(1.0f);
    const __m512 minus_x = _mm512_sub_ps(_mm512_setzero_ps(), x);
    const __m512 denom = _mm512_add_ps(one, exp512_fast_bf16(minus_x));
    return _mm512_div_ps(one, denom);
}
#endif
65 
/**
 * @brief SwiGLU forward pass on BF16 data: y = silu(gate) * up.
 *
 * Memory layout (row-major, one row per token):
 *   input  : [tokens, 2*dim]  columns [0, dim)     = gate half
 *                              columns [dim, 2*dim) = up/value half
 *   output : [tokens, dim]
 * Workspace: none.
 *
 * Each element is widened to FP32, computed, and rounded back to BF16.
 * With AVX-512, blocks of 16 columns use the polynomial sigmoid
 * (sigmoid512_fast_bf16). The remaining dim % 16 columns, or the whole row
 * when AVX-512 is unavailable, use sigmoid_scalar.
 * NOTE(review): sigmoid_scalar is defined elsewhere. Bit-exact agreement
 * with the SIMD approximation is not guaranteed by this file.
 *
 * NULL pointers, tokens <= 0 or dim <= 0 make the call a no-op.
 *
 * @param input   BF16 [gate | up] activations, tokens * 2*dim elements
 * @param output  BF16 result, tokens * dim elements
 * @param tokens  number of rows
 * @param dim     width of each half (and of each output row)
 */
void swiglu_forward_bf16(const uint16_t *input,
                         uint16_t *output,
                         int tokens,
                         int dim)
{
    if (!input || !output || tokens <= 0 || dim <= 0) {
        return;
    }

    const int T = tokens;
    const int D = dim;
    /* Widen before multiplying so 2*dim cannot overflow int. */
    const size_t in_stride = 2 * (size_t)D;

    for (int t = 0; t < T; ++t) {
        const uint16_t *gate = input + (size_t)t * in_stride;
        const uint16_t *up = gate + D; /* second half of the same row */
        uint16_t *out_row = output + (size_t)t * (size_t)D;
        int d = 0;

#if defined(__AVX512F__)
        /* 16 columns per step. `d <= D - 16` avoids computing d + 16,
         * which could overflow int for dim near INT_MAX. */
        for (; d <= D - 16; d += 16) {
            __m512 a = bf16_loadu_cvt_fp32(gate + d); /* gate */
            __m512 b = bf16_loadu_cvt_fp32(up + d);   /* value */

            __m512 s = sigmoid512_fast_bf16(a);       /* sigmoid(a) */
            __m512 silu_a = _mm512_mul_ps(a, s);      /* silu(a) = a * sigmoid(a) */
            __m512 y = _mm512_mul_ps(silu_a, b);      /* y = silu(a) * b */

            fp32_cvt_storeu_bf16(out_row + d, y);
        }
#endif

        /* Scalar path: tail columns, or the whole row without AVX-512. */
        for (; d < D; ++d) {
            float a = bf16_to_float(gate[d]);
            float b = bf16_to_float(up[d]);
            float silu_a = a * sigmoid_scalar(a);
            out_row[d] = float_to_bf16(silu_a * b);
        }
    }
}
107 
/**
 * @brief SwiGLU backward pass on BF16 data.
 *
 * For y = silu(a) * b with s = sigmoid(a), silu(a) = a * s:
 *   dL/da = dy * b * silu'(a),  silu'(a) = s + a * s * (1 - s)
 *   dL/db = dy * silu(a)
 *
 * Memory layout (row-major, one row per token):
 *   input    : [tokens, 2*dim]  [gate a | up b] (same as the forward input)
 *   d_output : [tokens, dim]    upstream gradient dy
 *   d_input  : [tokens, 2*dim]  [dL/da | dL/db]
 * Workspace: none.
 *
 * Within a row, each column's inputs are read before that column's
 * gradients are written.
 * NOTE(review): whether d_input may alias input is not stated by the API.
 * Confirm with callers before relying on in-place use.
 *
 * With AVX-512, blocks of 16 columns use sigmoid512_fast_bf16. The remaining
 * columns (or the whole row without AVX-512) use sigmoid_scalar, which is
 * defined elsewhere.
 *
 * NULL pointers, tokens <= 0 or dim <= 0 make the call a no-op.
 *
 * @param input     BF16 forward input, tokens * 2*dim elements
 * @param d_output  BF16 gradient of the output, tokens * dim elements
 * @param d_input   BF16 gradient of the input, tokens * 2*dim elements
 * @param tokens    number of rows
 * @param dim       width of each half
 */
void swiglu_backward_bf16(const uint16_t *input,
                          const uint16_t *d_output,
                          uint16_t *d_input,
                          int tokens,
                          int dim)
{
    if (!input || !d_output || !d_input || tokens <= 0 || dim <= 0) {
        return;
    }

    const int T = tokens;
    const int D = dim;
    /* Widen before multiplying so 2*dim cannot overflow int. */
    const size_t in_stride = 2 * (size_t)D;

#if defined(__AVX512F__)
    const __m512 one = _mm512_set1_ps(1.0f); /* loop-invariant constant */
#endif

    for (int t = 0; t < T; ++t) {
        const uint16_t *gate = input + (size_t)t * in_stride;
        const uint16_t *up = gate + D;
        const uint16_t *dy_row = d_output + (size_t)t * (size_t)D;
        uint16_t *d_gate = d_input + (size_t)t * in_stride;
        uint16_t *d_up = d_gate + D;
        int d = 0;

#if defined(__AVX512F__)
        /* 16 columns per step. `d <= D - 16` avoids overflowing d + 16. */
        for (; d <= D - 16; d += 16) {
            __m512 a = bf16_loadu_cvt_fp32(gate + d); /* gate */
            __m512 b = bf16_loadu_cvt_fp32(up + d);   /* value */
            __m512 dy = bf16_loadu_cvt_fp32(dy_row + d);

            __m512 s = sigmoid512_fast_bf16(a);                       /* sigmoid(a) */
            __m512 silu_a = _mm512_mul_ps(a, s);                      /* silu(a) = a * s */
            __m512 s_prime = _mm512_mul_ps(s, _mm512_sub_ps(one, s)); /* s * (1 - s) */
            __m512 silu_prime = _mm512_fmadd_ps(a, s_prime, s);       /* s + a * s_prime */

            /* dA = dy * b * silu_prime */
            __m512 dA = _mm512_mul_ps(dy, _mm512_mul_ps(b, silu_prime));
            /* dB = dy * silu */
            __m512 dB = _mm512_mul_ps(dy, silu_a);

            fp32_cvt_storeu_bf16(d_gate + d, dA);
            fp32_cvt_storeu_bf16(d_up + d, dB);
        }
#endif

        /* Scalar path: tail columns, or the whole row without AVX-512. */
        for (; d < D; ++d) {
            float a = bf16_to_float(gate[d]);
            float b = bf16_to_float(up[d]);
            float dy = bf16_to_float(dy_row[d]);

            float s = sigmoid_scalar(a);
            float silu_a = a * s;
            float s_prime = s * (1.0f - s);
            float silu_prime = s + a * s_prime;

            float dA = dy * b * silu_prime;
            float dB = dy * silu_a;

            d_gate[d] = float_to_bf16(dA);
            d_up[d] = float_to_bf16(dB);
        }
    }
}
static uint16_t float_to_bf16(float f)
Definition: bf16_utils.h:90
static float bf16_to_float(uint16_t v)
Definition: bf16_utils.h:38
float sigmoid_scalar(float x)
void swiglu_forward_bf16(const uint16_t *input, uint16_t *output, int tokens, int dim)
void swiglu_backward_bf16(const uint16_t *input, const uint16_t *d_output, uint16_t *d_input, int tokens, int dim)
static void silu(float *x, int n)
Definition: v6_simple.c:159