← Back to C-Kernel-Engine Docs Doxygen Source Documentation
softmax_kernels.c File Reference

Softmax forward/backward kernels with SIMD (SSE/AVX/AVX512) More...

#include <math.h>

Go to the source code of this file.

Functions

void backward_causal_softmax_head_major (float *d_scores, const float *weights, int num_heads, int num_tokens, int aligned_context_window)
 
void causal_softmax_head_major (float *scores, int num_heads, int num_tokens, int aligned_context_window)
 
void causal_softmax_head_major_exact (float *scores, int num_heads, int num_tokens, int aligned_context_window)
 

Detailed Description

Softmax forward/backward kernels with SIMD (SSE/AVX/AVX512)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

Softmax: y[i] = exp(x[i] - max(x)) / sum(exp(x - max(x)))

Definition in file softmax_kernels.c.

Function Documentation

◆ backward_causal_softmax_head_major()

/**
 * @brief Backward pass of the causal softmax (head-major layout, in place).
 *
 * Given the forward softmax probabilities `weights` and the upstream
 * gradient arriving in `d_scores`, computes per row:
 *     d_scores[j] = weights[j] * (d_scores[j] - dot(weights, d_scores))
 * over the causal prefix j in [0, i], then zeroes the masked columns
 * j in (i, num_tokens). Columns at or beyond num_tokens (alignment
 * padding) are left untouched.
 *
 * Layout: [num_heads, aligned_context_window, aligned_context_window],
 * row stride = aligned_context_window.
 *
 * @param d_scores               In: dL/dweights. Out: dL/dscores (in place).
 * @param weights                Softmax output from the forward pass (read-only).
 * @param num_heads              Number of attention heads.
 * @param num_tokens             Number of valid tokens (T).
 * @param aligned_context_window Row stride / per-head matrix dimension.
 *
 * Referenced by backward_causal_softmax_head_major_bf16().
 */
void backward_causal_softmax_head_major(float *d_scores,
                                        const float *weights,
                                        int num_heads,
                                        int num_tokens,
                                        int aligned_context_window)
{
    const int heads = num_heads;
    const int tokens = num_tokens;

    for (int h = 0; h < heads; ++h) {
        for (int row = 0; row < tokens; ++row) {
            const int offset = h * aligned_context_window * aligned_context_window
                             + row * aligned_context_window;
            float *grad = &d_scores[offset];
            const float *prob = &weights[offset];
            const int valid = row + 1; /* causal: columns 0..row are live */

#if defined(__AVX512F__)
            /* dot(prob, grad) over the live prefix, 16 lanes per step */
            __m512 acc = _mm512_setzero_ps();
            int k = 0;
            for (; k + 16 <= valid; k += 16) {
                __m512 p = _mm512_loadu_ps(&prob[k]);
                __m512 g = _mm512_loadu_ps(&grad[k]);
                acc = _mm512_fmadd_ps(p, g, acc);
            }
            float dot = _mm512_reduce_add_ps(acc);
            for (; k < valid; ++k) {
                dot += prob[k] * grad[k];
            }

            /* grad = prob * (grad - dot) */
            const __m512 dot_v = _mm512_set1_ps(dot);
            k = 0;
            for (; k + 16 <= valid; k += 16) {
                __m512 p = _mm512_loadu_ps(&prob[k]);
                __m512 g = _mm512_loadu_ps(&grad[k]);
                _mm512_storeu_ps(&grad[k],
                                 _mm512_mul_ps(p, _mm512_sub_ps(g, dot_v)));
            }
            for (; k < valid; ++k) {
                grad[k] = prob[k] * (grad[k] - dot);
            }

            /* masked (future) columns carry zero gradient */
            const __m512 zero_v = _mm512_setzero_ps();
            for (; k + 16 <= tokens; k += 16) {
                _mm512_storeu_ps(&grad[k], zero_v);
            }
            for (; k < tokens; ++k) {
                grad[k] = 0.0f;
            }

#elif defined(__AVX__)
            /* dot(prob, grad): AVX1 has no FMA, so multiply then add */
            __m256 acc = _mm256_setzero_ps();
            int k = 0;
            for (; k + 8 <= valid; k += 8) {
                __m256 p = _mm256_loadu_ps(&prob[k]);
                __m256 g = _mm256_loadu_ps(&grad[k]);
                acc = _mm256_add_ps(acc, _mm256_mul_ps(p, g));
            }
            float dot = hsum256_ps_softmax(acc);
            for (; k < valid; ++k) {
                dot += prob[k] * grad[k];
            }

            /* grad = prob * (grad - dot) */
            const __m256 dot_v = _mm256_set1_ps(dot);
            k = 0;
            for (; k + 8 <= valid; k += 8) {
                __m256 p = _mm256_loadu_ps(&prob[k]);
                __m256 g = _mm256_loadu_ps(&grad[k]);
                _mm256_storeu_ps(&grad[k],
                                 _mm256_mul_ps(p, _mm256_sub_ps(g, dot_v)));
            }
            for (; k < valid; ++k) {
                grad[k] = prob[k] * (grad[k] - dot);
            }

            /* masked (future) columns carry zero gradient */
            const __m256 zero_v = _mm256_setzero_ps();
            for (; k + 8 <= tokens; k += 8) {
                _mm256_storeu_ps(&grad[k], zero_v);
            }
            for (; k < tokens; ++k) {
                grad[k] = 0.0f;
            }

#else
            /* scalar fallback */
            float dot = 0.0f;
            for (int k = 0; k < valid; ++k) {
                dot += prob[k] * grad[k];
            }
            for (int k = 0; k < valid; ++k) {
                grad[k] = prob[k] * (grad[k] - dot);
            }
            for (int k = valid; k < tokens; ++k) {
                grad[k] = 0.0f;
            }
#endif
        }
    }
}

Referenced by backward_causal_softmax_head_major_bf16().

◆ causal_softmax_head_major()

void causal_softmax_head_major ( float *  scores,
int  num_heads,
int  num_tokens,
int  aligned_context_window 
)

Causal softmax (in-place, row-wise)

Test:

test_softmax.py::TestSoftmaxForward::test_causal_softmax

test_softmax.py::TestSoftmaxForward::test_causal_vs_softmax

test_attention.py::TestAttentionForward::test_softmax_correctness

Applies causal mask (j > i => 0) and softmax to scores matrix. In-place on [num_heads, T, T] scores matrix.

After changes: make test && make llamacpp-parity-full

Definition at line 144 of file softmax_kernels.c.

148 {
149  for (int h = 0; h < num_heads; ++h) {
150  for (int i = 0; i < num_tokens; ++i) {
151  int base = h * aligned_context_window * aligned_context_window
152  + i * aligned_context_window;
153  float *row = &scores[base];
154  int len = i + 1; // Number of valid elements (0..i inclusive)
155 
156 #if defined(__AVX512F__)
157  // Find max (vectorized)
158  __m512 max_vec = _mm512_set1_ps(-INFINITY);
159  int j = 0;
160  for (; j + 16 <= len; j += 16) {
161  __m512 v = _mm512_loadu_ps(&row[j]);
162  max_vec = _mm512_max_ps(max_vec, v);
163  }
164  float max_val = _mm512_reduce_max_ps(max_vec);
165  for (; j < len; ++j) {
166  if (row[j] > max_val) max_val = row[j];
167  }
168 
169  // Compute exp and sum (vectorized)
170  __m512 max_broadcast = _mm512_set1_ps(max_val);
171  __m512 sum_vec = _mm512_setzero_ps();
172  j = 0;
173  for (; j + 16 <= len; j += 16) {
174  __m512 v = _mm512_loadu_ps(&row[j]);
175  __m512 e = exp512_approx(_mm512_sub_ps(v, max_broadcast));
176  _mm512_storeu_ps(&row[j], e);
177  sum_vec = _mm512_add_ps(sum_vec, e);
178  }
179  float sum = _mm512_reduce_add_ps(sum_vec);
180  for (; j < len; ++j) {
181  float e = expf(row[j] - max_val);
182  row[j] = e;
183  sum += e;
184  }
185 
186  // Normalize (vectorized)
187  float inv_sum = 1.0f / sum;
188  __m512 inv_sum_vec = _mm512_set1_ps(inv_sum);
189  j = 0;
190  for (; j + 16 <= len; j += 16) {
191  __m512 v = _mm512_loadu_ps(&row[j]);
192  _mm512_storeu_ps(&row[j], _mm512_mul_ps(v, inv_sum_vec));
193  }
194  for (; j < len; ++j) {
195  row[j] *= inv_sum;
196  }
197 
198  // Zero out future tokens (vectorized)
199  __m512 zero = _mm512_setzero_ps();
200  for (; j + 16 <= num_tokens; j += 16) {
201  _mm512_storeu_ps(&row[j], zero);
202  }
203  for (; j < num_tokens; ++j) {
204  row[j] = 0.0f;
205  }
206 
207 #elif defined(__AVX2__)
208  // AVX2: Find max (vectorized)
209  __m256 max_vec = _mm256_set1_ps(-INFINITY);
210  int j = 0;
211  for (; j + 8 <= len; j += 8) {
212  __m256 v = _mm256_loadu_ps(&row[j]);
213  max_vec = _mm256_max_ps(max_vec, v);
214  }
215  float max_val = hmax256_ps(max_vec);
216  for (; j < len; ++j) {
217  if (row[j] > max_val) max_val = row[j];
218  }
219 
220  // Compute exp and sum (vectorized with fast exp)
221  __m256 max_broadcast = _mm256_set1_ps(max_val);
222  __m256 sum_vec = _mm256_setzero_ps();
223  j = 0;
224  for (; j + 8 <= len; j += 8) {
225  __m256 v = _mm256_loadu_ps(&row[j]);
226  __m256 e = exp256_approx(_mm256_sub_ps(v, max_broadcast));
227  _mm256_storeu_ps(&row[j], e);
228  sum_vec = _mm256_add_ps(sum_vec, e);
229  }
230  float sum = hsum256_ps_softmax(sum_vec);
231  for (; j < len; ++j) {
232  float e = expf(row[j] - max_val);
233  row[j] = e;
234  sum += e;
235  }
236 
237  // Normalize (vectorized)
238  float inv_sum = 1.0f / sum;
239  __m256 inv_sum_vec = _mm256_set1_ps(inv_sum);
240  j = 0;
241  for (; j + 8 <= len; j += 8) {
242  __m256 v = _mm256_loadu_ps(&row[j]);
243  _mm256_storeu_ps(&row[j], _mm256_mul_ps(v, inv_sum_vec));
244  }
245  for (; j < len; ++j) {
246  row[j] *= inv_sum;
247  }
248 
249  // Zero out future tokens (vectorized)
250  __m256 zero = _mm256_setzero_ps();
251  for (; j + 8 <= num_tokens; j += 8) {
252  _mm256_storeu_ps(&row[j], zero);
253  }
254  for (; j < num_tokens; ++j) {
255  row[j] = 0.0f;
256  }
257 
258 #elif defined(__AVX__)
259  // AVX1: vectorized max/sum/normalize, scalar exp
260  __m256 max_vec = _mm256_set1_ps(-INFINITY);
261  int j = 0;
262  for (; j + 8 <= len; j += 8) {
263  __m256 v = _mm256_loadu_ps(&row[j]);
264  max_vec = _mm256_max_ps(max_vec, v);
265  }
266  float max_val = hmax256_ps(max_vec);
267  for (; j < len; ++j) {
268  if (row[j] > max_val) max_val = row[j];
269  }
270 
271  // Compute exp and sum (scalar exp, no fast approx for AVX1)
272  float sum = 0.0f;
273  for (j = 0; j < len; ++j) {
274  float e = expf(row[j] - max_val);
275  row[j] = e;
276  sum += e;
277  }
278 
279  // Normalize (vectorized)
280  float inv_sum = 1.0f / sum;
281  __m256 inv_sum_vec = _mm256_set1_ps(inv_sum);
282  j = 0;
283  for (; j + 8 <= len; j += 8) {
284  __m256 v = _mm256_loadu_ps(&row[j]);
285  _mm256_storeu_ps(&row[j], _mm256_mul_ps(v, inv_sum_vec));
286  }
287  for (; j < len; ++j) {
288  row[j] *= inv_sum;
289  }
290 
291  // Zero out future tokens (vectorized)
292  __m256 zero = _mm256_setzero_ps();
293  for (; j + 8 <= num_tokens; j += 8) {
294  _mm256_storeu_ps(&row[j], zero);
295  }
296  for (; j < num_tokens; ++j) {
297  row[j] = 0.0f;
298  }
299 
300 #else
301  // Scalar fallback
302  float max_val = row[0];
303  for (int j = 1; j < len; ++j) {
304  if (row[j] > max_val) max_val = row[j];
305  }
306 
307  float sum = 0.0f;
308  for (int j = 0; j < len; ++j) {
309  float e = expf(row[j] - max_val);
310  row[j] = e;
311  sum += e;
312  }
313 
314  float inv_sum = 1.0f / sum;
315  for (int j = 0; j < len; ++j) {
316  row[j] *= inv_sum;
317  }
318 
319  for (int j = len; j < num_tokens; ++j) {
320  row[j] = 0.0f;
321  }
322 #endif
323  }
324  }
325 }

Referenced by attention_forward_causal_head_major(), attention_forward_causal_head_major_gqa(), and causal_softmax_head_major_bf16().

◆ causal_softmax_head_major_exact()

void causal_softmax_head_major_exact ( float *  scores,
int  num_heads,
int  num_tokens,
int  aligned_context_window 
)

Causal softmax (exact version using stdlib expf)

Test:

test_softmax.py::TestSoftmaxForward::test_causal_softmax_exact

test_softmax.py::TestSoftmaxForward::test_exact_vs_fast

Exact causal softmax using standard library expf for numerical accuracy reference.

After changes: make test

Definition at line 339 of file softmax_kernels.c.

343 {
344  for (int h = 0; h < num_heads; ++h) {
345  for (int i = 0; i < num_tokens; ++i) {
346  int base = h * aligned_context_window * aligned_context_window
347  + i * aligned_context_window;
348  float *row = &scores[base];
349  int len = i + 1;
350 
351  // Find max
352  float max_val = -INFINITY;
353  for (int j = 0; j < len; ++j) {
354  if (row[j] > max_val) max_val = row[j];
355  }
356 
357  // Compute exp and sum using standard library expf
358  float sum = 0.0f;
359  for (int j = 0; j < len; ++j) {
360  float e = expf(row[j] - max_val);
361  row[j] = e;
362  sum += e;
363  }
364 
365  // Normalize
366  float inv_sum = 1.0f / sum;
367  for (int j = 0; j < len; ++j) {
368  row[j] *= inv_sum;
369  }
370 
371  // Zero out future tokens
372  for (int j = len; j < num_tokens; ++j) {
373  row[j] = 0.0f;
374  }
375  }
376  }
377 }

Referenced by attention_forward_causal_head_major_exact(), and attention_forward_causal_head_major_gqa_exact().