← Back to C-Kernel-Engine Docs Doxygen Source Documentation
softmax_kernels.c
Go to the documentation of this file.
1 /**
2  * @file softmax_kernels.c
3  * @brief Softmax forward/backward kernels with SIMD (SSE/AVX/AVX512)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Softmax: y[i] = exp(x[i] - max(x)) / sum(exp(x - max(x)))
15  */
16 
#include <math.h>
#include <stddef.h>

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#include <immintrin.h>
#endif
22 
/* Fast vectorized exp approximation (good for softmax, ~1e-4 relative error) */
// Cody-Waite range reduction: exp(x) = 2^k * exp(r), with a short Taylor
// polynomial for exp(r) on [-ln2/2, ln2/2].
#if defined(__AVX512F__)
static inline __m512 exp512_approx(__m512 x) {
    // Clamp to avoid overflow/underflow of the final 2^k scaling
    x = _mm512_max_ps(x, _mm512_set1_ps(-88.0f));
    x = _mm512_min_ps(x, _mm512_set1_ps(88.0f));

    // k = round(x * log2(e)); then exp(x) = 2^k * exp(x - k*ln2)
    const __m512 log2e = _mm512_set1_ps(1.4426950408889634f);
    // ln2 split into hi/lo parts (c1 + c2 = ln2) for accurate reduction
    const __m512 c1 = _mm512_set1_ps(0.693359375f);
    const __m512 c2 = _mm512_set1_ps(-2.12194440e-4f);

    __m512 t = _mm512_mul_ps(x, log2e);
    __m512 ti = _mm512_roundscale_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    // Remainder in natural-log units: rx = x - ti * ln2, |rx| <= ln2/2
    __m512 rx = _mm512_sub_ps(x, _mm512_mul_ps(ti, c1));
    rx = _mm512_sub_ps(rx, _mm512_mul_ps(ti, c2));

    // Taylor polynomial for exp(rx): 1 + rx + rx^2/2 + rx^3/6 + rx^4/24
    // (max relative error ~4e-5 on [-ln2/2, ln2/2]).
    // BUG FIX: the previous coefficients (ln2, ln2^2/2, ...) were the series
    // for 2^t, but rx is already in natural-log units; evaluating 2^rx
    // instead of e^rx caused up to ~10% error mid-interval.
    const __m512 p0 = _mm512_set1_ps(1.0f);
    const __m512 p1 = _mm512_set1_ps(1.0f);
    const __m512 p2 = _mm512_set1_ps(0.5f);
    const __m512 p3 = _mm512_set1_ps(0.16666667f);
    const __m512 p4 = _mm512_set1_ps(0.041666668f);

    __m512 poly = _mm512_fmadd_ps(p4, rx, p3);
    poly = _mm512_fmadd_ps(poly, rx, p2);
    poly = _mm512_fmadd_ps(poly, rx, p1);
    poly = _mm512_fmadd_ps(poly, rx, p0);

    // Scale by 2^ti: build the float 2^ti directly in the exponent field
    __m512i ti_int = _mm512_cvtps_epi32(ti);
    ti_int = _mm512_add_epi32(ti_int, _mm512_set1_epi32(127));
    ti_int = _mm512_slli_epi32(ti_int, 23);
    __m512 scale = _mm512_castsi512_ps(ti_int);

    return _mm512_mul_ps(poly, scale);
}
#endif
64 
#if defined(__AVX2__)
// AVX2 version of the fast exp approximation (same reduction as AVX512 path).
// NOTE(review): uses _mm256_fmadd_ps unconditionally — assumes FMA is enabled
// alongside AVX2 (true for every shipped AVX2 CPU, but requires -mfma/-march).
static inline __m256 exp256_approx(__m256 x) {
    // Clamp to avoid overflow/underflow of the final 2^k scaling
    x = _mm256_max_ps(x, _mm256_set1_ps(-88.0f));
    x = _mm256_min_ps(x, _mm256_set1_ps(88.0f));

    // k = round(x * log2(e)); then exp(x) = 2^k * exp(x - k*ln2)
    const __m256 log2e = _mm256_set1_ps(1.4426950408889634f);
    // ln2 split into hi/lo parts (c1 + c2 = ln2) for accurate reduction
    const __m256 c1 = _mm256_set1_ps(0.693359375f);
    const __m256 c2 = _mm256_set1_ps(-2.12194440e-4f);

    __m256 t = _mm256_mul_ps(x, log2e);
    __m256 ti = _mm256_round_ps(t, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    // Remainder in natural-log units: rx = x - ti * ln2, |rx| <= ln2/2
    __m256 rx = _mm256_sub_ps(x, _mm256_mul_ps(ti, c1));
    rx = _mm256_sub_ps(rx, _mm256_mul_ps(ti, c2));

    // Taylor polynomial for exp(rx): 1 + rx + rx^2/2 + rx^3/6 + rx^4/24
    // BUG FIX: previous coefficients were the 2^t series (powers of ln2),
    // but rx is in natural-log units; that evaluated 2^rx instead of e^rx.
    const __m256 p0 = _mm256_set1_ps(1.0f);
    const __m256 p1 = _mm256_set1_ps(1.0f);
    const __m256 p2 = _mm256_set1_ps(0.5f);
    const __m256 p3 = _mm256_set1_ps(0.16666667f);
    const __m256 p4 = _mm256_set1_ps(0.041666668f);

    __m256 poly = _mm256_fmadd_ps(p4, rx, p3);
    poly = _mm256_fmadd_ps(poly, rx, p2);
    poly = _mm256_fmadd_ps(poly, rx, p1);
    poly = _mm256_fmadd_ps(poly, rx, p0);

    // Scale by 2^ti using AVX2 integer ops on the exponent field
    __m256i ti_int = _mm256_cvtps_epi32(ti);
    ti_int = _mm256_add_epi32(ti_int, _mm256_set1_epi32(127));
    ti_int = _mm256_slli_epi32(ti_int, 23);
    __m256 scale = _mm256_castsi256_ps(ti_int);

    return _mm256_mul_ps(poly, scale);
}
#endif
103 
// Horizontal reduction helpers shared by the AVX/AVX2 softmax paths
// (256-bit ops plus SSE/SSE3 only, so they work on both ISAs).
#if defined(__AVX__) || defined(__AVX2__)
/* Return the maximum of the 8 floats in v. */
static inline float hmax256_ps(__m256 v) {
    // Fold the upper 128-bit lane into the lower one, then reduce in 128 bits.
    __m128 m = _mm_max_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    m = _mm_max_ps(m, _mm_shuffle_ps(m, m, _MM_SHUFFLE(2, 3, 0, 1))); // pairwise
    m = _mm_max_ps(m, _mm_shuffle_ps(m, m, _MM_SHUFFLE(1, 0, 3, 2))); // cross-pair
    return _mm_cvtss_f32(m);
}

/* Return the sum of the 8 floats in v. */
static inline float hsum256_ps_softmax(__m256 v) {
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    s = _mm_hadd_ps(s, s); // [a0+a1, a2+a3, a0+a1, a2+a3]
    s = _mm_hadd_ps(s, s); // total replicated in every lane
    return _mm_cvtss_f32(s);
}
#endif
125 
126 // Causal softmax on head-major attention scores, copied and generalized
127 // from C-Transformer's apply_causal_softmax_head_major.
128 //
129 // scores layout: [head][query_token][key_token] with aligned_context_window stride:
130 // index = h * aligned_context_window * aligned_context_window
131 // + i * aligned_context_window
132 // + j
133 /**
134  * Causal softmax (in-place, row-wise)
135  * @test test_softmax.py::TestSoftmaxForward::test_causal_softmax
136  * @test test_softmax.py::TestSoftmaxForward::test_causal_vs_softmax
137  * @test test_attention.py::TestAttentionForward::test_softmax_correctness
138  *
139  * Applies causal mask (j > i => 0) and softmax to scores matrix.
140  * In-place on [num_heads, T, T] scores matrix.
141  *
142  * After changes: make test && make llamacpp-parity-full
143  */
144 void causal_softmax_head_major(float *scores,
145  int num_heads,
146  int num_tokens,
147  int aligned_context_window)
148 {
149  for (int h = 0; h < num_heads; ++h) {
150  for (int i = 0; i < num_tokens; ++i) {
151  int base = h * aligned_context_window * aligned_context_window
152  + i * aligned_context_window;
153  float *row = &scores[base];
154  int len = i + 1; // Number of valid elements (0..i inclusive)
155 
156 #if defined(__AVX512F__)
157  // Find max (vectorized)
158  __m512 max_vec = _mm512_set1_ps(-INFINITY);
159  int j = 0;
160  for (; j + 16 <= len; j += 16) {
161  __m512 v = _mm512_loadu_ps(&row[j]);
162  max_vec = _mm512_max_ps(max_vec, v);
163  }
164  float max_val = _mm512_reduce_max_ps(max_vec);
165  for (; j < len; ++j) {
166  if (row[j] > max_val) max_val = row[j];
167  }
168 
169  // Compute exp and sum (vectorized)
170  __m512 max_broadcast = _mm512_set1_ps(max_val);
171  __m512 sum_vec = _mm512_setzero_ps();
172  j = 0;
173  for (; j + 16 <= len; j += 16) {
174  __m512 v = _mm512_loadu_ps(&row[j]);
175  __m512 e = exp512_approx(_mm512_sub_ps(v, max_broadcast));
176  _mm512_storeu_ps(&row[j], e);
177  sum_vec = _mm512_add_ps(sum_vec, e);
178  }
179  float sum = _mm512_reduce_add_ps(sum_vec);
180  for (; j < len; ++j) {
181  float e = expf(row[j] - max_val);
182  row[j] = e;
183  sum += e;
184  }
185 
186  // Normalize (vectorized)
187  float inv_sum = 1.0f / sum;
188  __m512 inv_sum_vec = _mm512_set1_ps(inv_sum);
189  j = 0;
190  for (; j + 16 <= len; j += 16) {
191  __m512 v = _mm512_loadu_ps(&row[j]);
192  _mm512_storeu_ps(&row[j], _mm512_mul_ps(v, inv_sum_vec));
193  }
194  for (; j < len; ++j) {
195  row[j] *= inv_sum;
196  }
197 
198  // Zero out future tokens (vectorized)
199  __m512 zero = _mm512_setzero_ps();
200  for (; j + 16 <= num_tokens; j += 16) {
201  _mm512_storeu_ps(&row[j], zero);
202  }
203  for (; j < num_tokens; ++j) {
204  row[j] = 0.0f;
205  }
206 
207 #elif defined(__AVX2__)
208  // AVX2: Find max (vectorized)
209  __m256 max_vec = _mm256_set1_ps(-INFINITY);
210  int j = 0;
211  for (; j + 8 <= len; j += 8) {
212  __m256 v = _mm256_loadu_ps(&row[j]);
213  max_vec = _mm256_max_ps(max_vec, v);
214  }
215  float max_val = hmax256_ps(max_vec);
216  for (; j < len; ++j) {
217  if (row[j] > max_val) max_val = row[j];
218  }
219 
220  // Compute exp and sum (vectorized with fast exp)
221  __m256 max_broadcast = _mm256_set1_ps(max_val);
222  __m256 sum_vec = _mm256_setzero_ps();
223  j = 0;
224  for (; j + 8 <= len; j += 8) {
225  __m256 v = _mm256_loadu_ps(&row[j]);
226  __m256 e = exp256_approx(_mm256_sub_ps(v, max_broadcast));
227  _mm256_storeu_ps(&row[j], e);
228  sum_vec = _mm256_add_ps(sum_vec, e);
229  }
230  float sum = hsum256_ps_softmax(sum_vec);
231  for (; j < len; ++j) {
232  float e = expf(row[j] - max_val);
233  row[j] = e;
234  sum += e;
235  }
236 
237  // Normalize (vectorized)
238  float inv_sum = 1.0f / sum;
239  __m256 inv_sum_vec = _mm256_set1_ps(inv_sum);
240  j = 0;
241  for (; j + 8 <= len; j += 8) {
242  __m256 v = _mm256_loadu_ps(&row[j]);
243  _mm256_storeu_ps(&row[j], _mm256_mul_ps(v, inv_sum_vec));
244  }
245  for (; j < len; ++j) {
246  row[j] *= inv_sum;
247  }
248 
249  // Zero out future tokens (vectorized)
250  __m256 zero = _mm256_setzero_ps();
251  for (; j + 8 <= num_tokens; j += 8) {
252  _mm256_storeu_ps(&row[j], zero);
253  }
254  for (; j < num_tokens; ++j) {
255  row[j] = 0.0f;
256  }
257 
258 #elif defined(__AVX__)
259  // AVX1: vectorized max/sum/normalize, scalar exp
260  __m256 max_vec = _mm256_set1_ps(-INFINITY);
261  int j = 0;
262  for (; j + 8 <= len; j += 8) {
263  __m256 v = _mm256_loadu_ps(&row[j]);
264  max_vec = _mm256_max_ps(max_vec, v);
265  }
266  float max_val = hmax256_ps(max_vec);
267  for (; j < len; ++j) {
268  if (row[j] > max_val) max_val = row[j];
269  }
270 
271  // Compute exp and sum (scalar exp, no fast approx for AVX1)
272  float sum = 0.0f;
273  for (j = 0; j < len; ++j) {
274  float e = expf(row[j] - max_val);
275  row[j] = e;
276  sum += e;
277  }
278 
279  // Normalize (vectorized)
280  float inv_sum = 1.0f / sum;
281  __m256 inv_sum_vec = _mm256_set1_ps(inv_sum);
282  j = 0;
283  for (; j + 8 <= len; j += 8) {
284  __m256 v = _mm256_loadu_ps(&row[j]);
285  _mm256_storeu_ps(&row[j], _mm256_mul_ps(v, inv_sum_vec));
286  }
287  for (; j < len; ++j) {
288  row[j] *= inv_sum;
289  }
290 
291  // Zero out future tokens (vectorized)
292  __m256 zero = _mm256_setzero_ps();
293  for (; j + 8 <= num_tokens; j += 8) {
294  _mm256_storeu_ps(&row[j], zero);
295  }
296  for (; j < num_tokens; ++j) {
297  row[j] = 0.0f;
298  }
299 
300 #else
301  // Scalar fallback
302  float max_val = row[0];
303  for (int j = 1; j < len; ++j) {
304  if (row[j] > max_val) max_val = row[j];
305  }
306 
307  float sum = 0.0f;
308  for (int j = 0; j < len; ++j) {
309  float e = expf(row[j] - max_val);
310  row[j] = e;
311  sum += e;
312  }
313 
314  float inv_sum = 1.0f / sum;
315  for (int j = 0; j < len; ++j) {
316  row[j] *= inv_sum;
317  }
318 
319  for (int j = len; j < num_tokens; ++j) {
320  row[j] = 0.0f;
321  }
322 #endif
323  }
324  }
325 }
326 
327 // Scalar-only exact causal softmax using standard library expf.
328 // This is slower than causal_softmax_head_major but provides maximum accuracy.
329 // Used by BF16 attention wrapper where approximation error accumulates.
330 /**
331  * Causal softmax (exact version using stdlib expf)
332  * @test test_softmax.py::TestSoftmaxForward::test_causal_softmax_exact
333  * @test test_softmax.py::TestSoftmaxForward::test_exact_vs_fast
334  *
335  * Exact causal softmax using standard library expf for numerical accuracy reference.
336  *
337  * After changes: make test
338  */
340  int num_heads,
341  int num_tokens,
342  int aligned_context_window)
343 {
344  for (int h = 0; h < num_heads; ++h) {
345  for (int i = 0; i < num_tokens; ++i) {
346  int base = h * aligned_context_window * aligned_context_window
347  + i * aligned_context_window;
348  float *row = &scores[base];
349  int len = i + 1;
350 
351  // Find max
352  float max_val = -INFINITY;
353  for (int j = 0; j < len; ++j) {
354  if (row[j] > max_val) max_val = row[j];
355  }
356 
357  // Compute exp and sum using standard library expf
358  float sum = 0.0f;
359  for (int j = 0; j < len; ++j) {
360  float e = expf(row[j] - max_val);
361  row[j] = e;
362  sum += e;
363  }
364 
365  // Normalize
366  float inv_sum = 1.0f / sum;
367  for (int j = 0; j < len; ++j) {
368  row[j] *= inv_sum;
369  }
370 
371  // Zero out future tokens
372  for (int j = len; j < num_tokens; ++j) {
373  row[j] = 0.0f;
374  }
375  }
376  }
377 }
378 
// Backward pass for causal softmax on head-major scores, adapted from
// C-Transformer's backward_causal_softmax. Operates in-place on d_scores,
// using the cached forward softmax output `weights`.
//
// For softmax y = softmax(x): dL/dx_j = y_j * (dL/dy_j - sum_k y_k * dL/dy_k)
//
// FIX: restored the function signature line that was dropped from this file
// (the body previously began mid-parameter-list); signature taken from the
// documented prototype.
/**
 * @param d_scores               upstream gradient dL/dweights in, dL/dscores out
 * @param weights                cached forward softmax output (same layout)
 * @param num_heads              number of attention heads
 * @param num_tokens             number of valid tokens T
 * @param aligned_context_window row stride in floats (>= num_tokens)
 */
void backward_causal_softmax_head_major(float *d_scores,
                                        const float *weights,
                                        int num_heads,
                                        int num_tokens,
                                        int aligned_context_window)
{
    int H = num_heads;
    int T = num_tokens;

    for (int h = 0; h < H; ++h) {
        for (int i = 0; i < T; ++i) {
            // size_t offset: int arithmetic overflows for large acw * heads
            size_t base = (size_t)h * aligned_context_window * aligned_context_window
                        + (size_t)i * aligned_context_window;
            float *drow = &d_scores[base];
            const float *wrow = &weights[base];
            int len = i + 1; // valid (non-future) elements in this row

#if defined(__AVX512F__)
            // Compute dot product sum_k w_k * dw_k (vectorized)
            __m512 dot_vec = _mm512_setzero_ps();
            int j = 0;
            for (; j + 16 <= len; j += 16) {
                __m512 w = _mm512_loadu_ps(&wrow[j]);
                __m512 dw = _mm512_loadu_ps(&drow[j]);
                dot_vec = _mm512_fmadd_ps(w, dw, dot_vec);
            }
            float dot_product = _mm512_reduce_add_ps(dot_vec);
            for (; j < len; ++j) {
                dot_product += wrow[j] * drow[j];
            }

            // Compute gradient: d_scores = w * (dw - dot_product)
            __m512 dot_broadcast = _mm512_set1_ps(dot_product);
            j = 0;
            for (; j + 16 <= len; j += 16) {
                __m512 w = _mm512_loadu_ps(&wrow[j]);
                __m512 dw = _mm512_loadu_ps(&drow[j]);
                __m512 diff = _mm512_sub_ps(dw, dot_broadcast);
                __m512 result = _mm512_mul_ps(w, diff);
                _mm512_storeu_ps(&drow[j], result);
            }
            for (; j < len; ++j) {
                drow[j] = wrow[j] * (drow[j] - dot_product);
            }

            // Zero out future tokens (j == len here)
            __m512 zero = _mm512_setzero_ps();
            for (; j + 16 <= T; j += 16) {
                _mm512_storeu_ps(&drow[j], zero);
            }
            for (; j < T; ++j) {
                drow[j] = 0.0f;
            }

#elif defined(__AVX__)
            // Compute dot product (vectorized)
            __m256 dot_vec = _mm256_setzero_ps();
            int j = 0;
            for (; j + 8 <= len; j += 8) {
                __m256 w = _mm256_loadu_ps(&wrow[j]);
                __m256 dw = _mm256_loadu_ps(&drow[j]);
                // No FMA in AVX1: use mul + add
                __m256 prod = _mm256_mul_ps(w, dw);
                dot_vec = _mm256_add_ps(dot_vec, prod);
            }
            float dot_product = hsum256_ps_softmax(dot_vec);
            for (; j < len; ++j) {
                dot_product += wrow[j] * drow[j];
            }

            // Compute gradient: d_scores = w * (dw - dot_product)
            __m256 dot_broadcast = _mm256_set1_ps(dot_product);
            j = 0;
            for (; j + 8 <= len; j += 8) {
                __m256 w = _mm256_loadu_ps(&wrow[j]);
                __m256 dw = _mm256_loadu_ps(&drow[j]);
                __m256 diff = _mm256_sub_ps(dw, dot_broadcast);
                __m256 result = _mm256_mul_ps(w, diff);
                _mm256_storeu_ps(&drow[j], result);
            }
            for (; j < len; ++j) {
                drow[j] = wrow[j] * (drow[j] - dot_product);
            }

            // Zero out future tokens (j == len here)
            __m256 zero = _mm256_setzero_ps();
            for (; j + 8 <= T; j += 8) {
                _mm256_storeu_ps(&drow[j], zero);
            }
            for (; j < T; ++j) {
                drow[j] = 0.0f;
            }

#else
            // Scalar fallback
            float dot_product = 0.0f;
            for (int j = 0; j < len; ++j) {
                dot_product += wrow[j] * drow[j];
            }

            for (int j = 0; j < len; ++j) {
                drow[j] = wrow[j] * (drow[j] - dot_product);
            }

            for (int j = len; j < T; ++j) {
                drow[j] = 0.0f;
            }
#endif
        }
    }
}
493 
void backward_causal_softmax_head_major(float *d_scores, const float *weights, int num_heads, int num_tokens, int aligned_context_window)
void causal_softmax_head_major_exact(float *scores, int num_heads, int num_tokens, int aligned_context_window)
void causal_softmax_head_major(float *scores, int num_heads, int num_tokens, int aligned_context_window)