28 #include <immintrin.h>
55 if (!scores || !indices || n <= 0 || k <= 0) {
65 float local_values[k];
66 for (
int i = 0; i < k; i++) {
68 local_values[i] = scores[i];
73 for (
int i = 1; i < k; i++) {
74 if (local_values[i] < local_values[min_idx]) {
80 for (
int i = k; i < n; i++) {
81 if (scores[i] > local_values[min_idx]) {
84 local_values[min_idx] = scores[i];
88 for (
int j = 1; j < k; j++) {
89 if (local_values[j] < local_values[min_idx]) {
97 for (
int i = 1; i < k; i++) {
98 float val = local_values[i];
101 while (j >= 0 && local_values[j] < val) {
102 local_values[j + 1] = local_values[j];
103 indices[j + 1] = indices[j];
106 local_values[j + 1] = val;
107 indices[j + 1] = idx;
112 for (
int i = 0; i < k; i++) {
113 values[i] = local_values[i];
140 if (!scores || !indices || !weights || n <= 0 || k <= 0) {
150 topk_f32(scores, n, k, indices, values);
154 float max_val = values[0];
155 for (
int i = 1; i < k; i++) {
156 if (values[i] > max_val) {
163 for (
int i = 0; i < k; i++) {
164 weights[i] = expf(values[i] - max_val);
169 float inv_sum = 1.0f / sum;
170 for (
int i = 0; i < k; i++) {
171 weights[i] *= inv_sum;
198 if (!scores || !indices || num_tokens <= 0 || n_experts <= 0 || k <= 0) {
202 for (
int t = 0; t < num_tokens; t++) {
203 const float *token_scores = scores + t * n_experts;
204 int *token_indices = indices + t * k;
207 float *token_weights = weights + t * k;
210 topk_f32(token_scores, n_experts, k, token_indices, NULL);
228 if (!scores || n <= 0) {
233 float max_val = scores[0];
238 __m512 vmax = _mm512_set1_ps(-FLT_MAX);
239 __m512i vidx = _mm512_setzero_si512();
240 __m512i vcur_max_idx = _mm512_setzero_si512();
243 for (; i + 16 <= n; i += 16) {
244 __m512 v = _mm512_loadu_ps(&scores[i]);
245 __m512i cur_idx = _mm512_add_epi32(
246 _mm512_set1_epi32(i),
247 _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
250 __mmask16 gt_mask = _mm512_cmp_ps_mask(v, vmax, _CMP_GT_OQ);
251 vmax = _mm512_mask_blend_ps(gt_mask, vmax, v);
252 vcur_max_idx = _mm512_mask_blend_epi32(gt_mask, vcur_max_idx, cur_idx);
258 _mm512_storeu_ps(vals, vmax);
259 _mm512_storeu_si512(idxs, vcur_max_idx);
263 for (
int j = 1; j < 16; j++) {
264 if (vals[j] > max_val) {
272 if (scores[i] > max_val) {
283 for (
int i = 1; i < n; i++) {
284 if (scores[i] > max_val) {
void topk_batched_f32(const float *scores, int num_tokens, int n_experts, int k, int *indices, float *weights)
Batched top-K selection for multiple tokens.
int argmax_f32(const float *scores, int n)
Find index of maximum value.
void topk_f32(const float *scores, int n, int k, int *indices, float *values)
Find top-K indices and values from a score vector.
void topk_softmax_f32(const float *scores, int n, int k, int *indices, float *weights)
Find top-K indices with softmax-normalized weights.