← Back to C-Kernel-Engine Docs Doxygen Source Documentation
topk_kernels.c
Go to the documentation of this file.
1 /**
2  * @file topk_kernels.c
3  * @brief Top-K selection kernels for MoE router dispatch
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Provides efficient top-K selection from a score vector.
15  * Used in Mixture-of-Experts models to select which experts process each token.
16  *
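 * Operations:
 * - topk_f32: Find top-K indices and values from N scores
 * - topk_softmax_f32: Top-K with softmax normalization of selected scores
 * - topk_batched_f32: Per-token top-K (optionally with softmax) over a
 *   row-major [num_tokens, n_experts] score matrix
 * - argmax_f32: Index of the maximum score (top-1 special case)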
20  */
21 
22 #include <stdint.h>
23 #include <stddef.h>
24 #include <float.h>
25 #include <math.h>
26 
27 #ifdef __AVX512F__
28 #include <immintrin.h>
29 #endif
30 
/* =============================================================================
 * Top-K Selection (scalar reference)
 *
 * Finds the K largest values in an array and returns their indices and values.
 * The caller's index buffer holds the current K best candidates; the weakest
 * candidate is tracked with a linear scan and is evicted whenever a strictly
 * larger score arrives (a linear-scan replacement scheme, not a heap).
 *
 * For small K (typical MoE: K=2-8), this is efficient. O(N*K) worst case.
 * ============================================================================= */

/*
 * Ordering shared by selection and sorting: returns nonzero when the score
 * va (at index ia) ranks strictly below vb (at index ib). Larger values rank
 * higher; on equal values the lower index ranks higher, so ties resolve the
 * same way as argmax_f32 (first occurrence wins).
 */
static int topk_ranks_below(float va, int ia, float vb, int ib)
{
    return va < vb || (va == vb && ia > ib);
}

/* Slot in indices[0..k) that holds the lowest-ranked candidate. */
static int topk_weakest_slot(const float *scores, const int *indices, int k)
{
    int weakest = 0;
    for (int j = 1; j < k; j++) {
        if (topk_ranks_below(scores[indices[j]], indices[j],
                             scores[indices[weakest]], indices[weakest])) {
            weakest = j;
        }
    }
    return weakest;
}

/**
 * @brief Find top-K indices and values from a score vector
 *
 * Selected entries are ordered by descending score. Equal scores are ordered
 * by ascending index, and when equal scores compete for the last slot the
 * lower indices are kept. No scratch memory is used: candidate scores are read
 * back through scores[indices[j]], so any k is safe without a stack buffer.
 * Plain float comparisons are used, so NaN scores give an unspecified ranking.
 *
 * Does nothing if @p scores or @p indices is NULL, or if n <= 0 or k <= 0.
 *
 * @param scores  Input scores [n]
 * @param n       Number of scores (e.g., number of experts)
 * @param k       Number of top scores to select; clamped to n
 * @param indices Output: indices of top-K scores [min(k, n)], sorted descending by value
 * @param values  Output: top-K score values [min(k, n)], sorted descending (can be NULL).
 *                Written only after selection finishes, from scores[indices[i]];
 *                must not overlap @p scores.
 */
void topk_f32(const float *scores,
              int n,
              int k,
              int *indices,
              float *values)
{
    if (!scores || !indices || n <= 0 || k <= 0) {
        return;
    }

    /* Clamp k to n */
    if (k > n) {
        k = n;
    }

    /* Seed the candidate set with the first k positions. */
    for (int i = 0; i < k; i++) {
        indices[i] = i;
    }

    /* Scan the remaining scores. Every candidate index is smaller than i, so
     * under topk_ranks_below a newcomer only outranks the weakest candidate
     * when its score is strictly greater. */
    int weakest = topk_weakest_slot(scores, indices, k);
    for (int i = k; i < n; i++) {
        if (scores[i] > scores[indices[weakest]]) {
            indices[weakest] = i;
            weakest = topk_weakest_slot(scores, indices, k);
        }
    }

    /* Insertion sort into descending rank order (k is small). */
    for (int i = 1; i < k; i++) {
        const int idx = indices[i];
        const float val = scores[idx];
        int j = i - 1;
        while (j >= 0 && topk_ranks_below(scores[indices[j]], indices[j], val, idx)) {
            indices[j + 1] = indices[j];
            j--;
        }
        indices[j + 1] = idx;
    }

    /* Copy values if output requested */
    if (values) {
        for (int i = 0; i < k; i++) {
            values[i] = scores[indices[i]];
        }
    }
}
117 
/* =============================================================================
 * Top-K with Softmax Normalization
 *
 * Finds top-K and normalizes the selected scores using softmax.
 * This is the standard MoE gating: select experts, then compute routing weights.
 * ============================================================================= */

/**
 * @brief Find top-K indices with softmax-normalized weights
 *
 * Selects the top-K scores with topk_f32, then applies a numerically stable
 * softmax (exp(x - max) / sum) to just those K scores. If the largest selected
 * score is infinite, exp(x - max) would be NaN, so the limiting softmax is used
 * instead: the entries equal to that infinite max share the weight evenly and
 * the rest get 0. When every selected score is -inf, this makes the weights uniform.
 *
 * Does nothing if any pointer is NULL, or if n <= 0 or k <= 0.
 *
 * @param scores  Input scores [n] (router logits)
 * @param n       Number of scores
 * @param k       Number of top scores to select; clamped to n
 * @param indices Output: indices of top-K scores [min(k, n)], sorted by descending score
 * @param weights Output: softmax-normalized weights for the selected scores
 *                [min(k, n)], in the same order as @p indices, summing to 1.0.
 *                Also used as scratch for the raw selected scores, so it must
 *                not overlap @p scores.
 */
void topk_softmax_f32(const float *scores,
                      int n,
                      int k,
                      int *indices,
                      float *weights)
{
    if (!scores || !indices || !weights || n <= 0 || k <= 0) {
        return;
    }

    if (k > n) {
        k = n;
    }

    /* Let topk_f32 write the raw selected scores straight into weights; the
     * softmax below then runs in place, so no stack buffer sized by k is needed. */
    topk_f32(scores, n, k, indices, weights);

    /* Find max for numerical stability */
    float max_val = weights[0];
    for (int i = 1; i < k; i++) {
        if (weights[i] > max_val) {
            max_val = weights[i];
        }
    }

    /* max_val always equals some selected entry, so ties >= 1 below. */
    if (isinf(max_val)) {
        int ties = 0;
        for (int i = 0; i < k; i++) {
            if (weights[i] == max_val) {
                ties++;
            }
        }
        const float share = 1.0f / (float)ties;
        for (int i = 0; i < k; i++) {
            weights[i] = (weights[i] == max_val) ? share : 0.0f;
        }
        return;
    }

    /* Compute exp and sum; the max entry contributes exp(0) = 1, so sum >= 1. */
    float sum = 0.0f;
    for (int i = 0; i < k; i++) {
        weights[i] = expf(weights[i] - max_val);
        sum += weights[i];
    }

    /* Normalize */
    const float inv_sum = 1.0f / sum;
    for (int i = 0; i < k; i++) {
        weights[i] *= inv_sum;
    }
}
174 
/* =============================================================================
 * Batched Top-K (for multiple tokens)
 *
 * Process multiple tokens at once, each with its own routing scores.
 * ============================================================================= */

/**
 * @brief Batched top-K selection for multiple tokens
 *
 * Each row of @p scores is handled independently. With @p weights, a row goes
 * through topk_softmax_f32; without it, through topk_f32 (indices only).
 *
 * Does nothing if @p scores or @p indices is NULL, or if any size is <= 0.
 *
 * @param scores     Input scores, row-major [num_tokens, n_experts]
 * @param num_tokens Number of tokens (rows)
 * @param n_experts  Number of experts (columns)
 * @param k          Number of experts to select per token. Output rows are
 *                   always k wide, but only min(k, n_experts) entries per row
 *                   are written.
 * @param indices    Output: selected expert indices [num_tokens, k]
 * @param weights    Output: routing weights [num_tokens, k] (can be NULL for no softmax)
 */
void topk_batched_f32(const float *scores,
                      int num_tokens,
                      int n_experts,
                      int k,
                      int *indices,
                      float *weights)
{
    if (!scores || !indices || num_tokens <= 0 || n_experts <= 0 || k <= 0) {
        return;
    }

    /* Row offsets are computed in size_t: t * n_experts (or t * k) can exceed
     * INT_MAX for large batches even though each factor fits in an int, and
     * signed overflow is undefined behavior. */
    const size_t score_stride = (size_t)n_experts;
    const size_t out_stride = (size_t)k;

    for (int t = 0; t < num_tokens; t++) {
        const size_t row = (size_t)t;
        const float *token_scores = scores + row * score_stride;
        int *token_indices = indices + row * out_stride;

        if (weights) {
            float *token_weights = weights + row * out_stride;
            topk_softmax_f32(token_scores, n_experts, k, token_indices, token_weights);
        } else {
            topk_f32(token_scores, n_experts, k, token_indices, NULL);
        }
    }
}
214 
/* =============================================================================
 * Argmax (special case of top-1)
 * ============================================================================= */

/**
 * @brief Find index of maximum value
 *
 * Returns the first index holding the maximum (ties go to the lowest index).
 * The running maximum only moves on a strictly greater value, so NaN scores
 * are never selected, except that a NaN at index 0 makes the result 0. The
 * AVX-512 path is written to return exactly what the scalar loop returns.
 *
 * @param scores Input scores [n]
 * @param n Number of scores
 * @return Index of maximum value, or -1 if @p scores is NULL or n <= 0
 */
int argmax_f32(const float *scores, int n)
{
    if (!scores || n <= 0) {
        return -1;
    }

    int max_idx = 0;
    float max_val = scores[0];

#ifdef __AVX512F__
    /* AVX-512 vectorized argmax for large arrays */
    if (n >= 16) {
        /* Seed every lane with (scores[0], index 0), the same starting point as
         * the scalar loop. A -FLT_MAX sentinel here would lose -inf/-FLT_MAX
         * edge cases and a NaN at index 0. */
        __m512 vbest = _mm512_set1_ps(scores[0]);
        __m512i vbest_idx = _mm512_setzero_si512();
        __m512i vlane_idx = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7,
                                              8, 9, 10, 11, 12, 13, 14, 15);
        const __m512i vstep = _mm512_set1_epi32(16);

        int i = 0;
        /* 'i <= n - 16' rather than 'i + 16 <= n' to avoid signed overflow
         * when n is close to INT_MAX (n >= 16 here). */
        for (; i <= n - 16; i += 16) {
            __m512 v = _mm512_loadu_ps(&scores[i]);
            /* Strictly-greater keeps the earliest index within each lane. */
            __mmask16 gt_mask = _mm512_cmp_ps_mask(v, vbest, _CMP_GT_OQ);
            vbest = _mm512_mask_blend_ps(gt_mask, vbest, v);
            vbest_idx = _mm512_mask_blend_epi32(gt_mask, vbest_idx, vlane_idx);
            vlane_idx = _mm512_add_epi32(vlane_idx, vstep);
        }

        /* Horizontal reduction: on equal values prefer the smaller index so
         * the result matches the scalar first-occurrence rule. */
        float vals[16];
        int idxs[16];
        _mm512_storeu_ps(vals, vbest);
        _mm512_storeu_si512(idxs, vbest_idx);

        max_val = vals[0];
        max_idx = idxs[0];
        for (int j = 1; j < 16; j++) {
            if (vals[j] > max_val || (vals[j] == max_val && idxs[j] < max_idx)) {
                max_val = vals[j];
                max_idx = idxs[j];
            }
        }

        /* Handle remainder (all indices here are larger, so strict '>' keeps ties early). */
        for (; i < n; i++) {
            if (scores[i] > max_val) {
                max_val = scores[i];
                max_idx = i;
            }
        }

        return max_idx;
    }
#endif

    /* Scalar fallback */
    for (int i = 1; i < n; i++) {
        if (scores[i] > max_val) {
            max_val = scores[i];
            max_idx = i;
        }
    }

    return max_idx;
}
void topk_batched_f32(const float *scores, int num_tokens, int n_experts, int k, int *indices, float *weights)
Batched top-K selection for multiple tokens.
Definition: topk_kernels.c:191
int argmax_f32(const float *scores, int n)
Find index of maximum value.
Definition: topk_kernels.c:226
void topk_f32(const float *scores, int n, int k, int *indices, float *values)
Find top-K indices and values from a score vector.
Definition: topk_kernels.c:49
void topk_softmax_f32(const float *scores, int n, int k, int *indices, float *weights)
Find top-K indices with softmax-normalized weights.
Definition: topk_kernels.c:134