28 #include <immintrin.h>
32 #include <immintrin.h>
59 if (!y || !x || n <= 0) {
66 __m512 valpha = _mm512_set1_ps(alpha);
67 for (; i + 16 <= n; i += 16) {
68 __m512 vy = _mm512_loadu_ps(&y[i]);
69 __m512 vx = _mm512_loadu_ps(&x[i]);
70 vy = _mm512_fmadd_ps(vx, valpha, vy);
71 _mm512_storeu_ps(&y[i], vy);
76 __m256 valpha256 = _mm256_set1_ps(alpha);
77 for (; i + 8 <= n; i += 8) {
78 __m256 vy = _mm256_loadu_ps(&y[i]);
79 __m256 vx = _mm256_loadu_ps(&x[i]);
80 vy = _mm256_fmadd_ps(vx, valpha256, vy);
81 _mm256_storeu_ps(&y[i], vy);
110 if (!y || !x || n <= 0) {
117 __m512 valpha = _mm512_set1_ps(alpha);
118 for (; i + 16 <= n; i += 16) {
119 __m512 vx = _mm512_loadu_ps(&x[i]);
120 __m512 vy = _mm512_mul_ps(vx, valpha);
121 _mm512_storeu_ps(&y[i], vy);
126 __m256 valpha256 = _mm256_set1_ps(alpha);
127 for (; i + 8 <= n; i += 8) {
128 __m256 vx = _mm256_loadu_ps(&x[i]);
129 __m256 vy = _mm256_mul_ps(vx, valpha256);
130 _mm256_storeu_ps(&y[i], vy);
156 const float **vectors,
157 const float *weights,
161 if (!y || !vectors || !weights || k <= 0 || n <= 0) {
169 for (
int i = 1; i < k; i++) {
170 axpy_f32(y, vectors[i], weights[i], n);
197 memset(y, 0, n *
sizeof(
float));
229 if (!Y || !X || num_tokens <= 0 || dim <= 0) {
234 if (y_stride <= 0) y_stride = dim;
235 if (x_stride <= 0) x_stride = dim;
237 for (
int t = 0; t < num_tokens; t++) {
238 axpy_f32(Y + t * y_stride, X + t * x_stride, alpha, dim);
257 const float *expert_output,
258 float routing_weight,
261 axpy_f32(output, expert_output, routing_weight, hidden_dim);
void axpy_f32(float *y, const float *x, float alpha, int n)
In-place AXPY: y += alpha * x.
void moe_accumulate_expert_f32(float *output, const float *expert_output, float routing_weight, int hidden_dim)
Accumulate expert output: output += routing_weight * expert_output.
void axpy_zero_f32(float *y, const float *x, float alpha, int n)
Zero output then accumulate: y = 0; y += alpha * x.
void weighted_sum_f32(float *y, const float **vectors, const float *weights, int k, int n)
Weighted sum of k vectors, overwriting y: y = sum_i(weights[i] * vectors[i]).
void axpy_2d_f32(float *Y, const float *X, float alpha, int num_tokens, int dim, int y_stride, int x_stride)
Batched AXPY for 2D tensors: Y[t,:] += alpha * X[t,:].
void scal_copy_f32(float *y, const float *x, float alpha, int n)
Scaled copy: y = alpha * x.