#include <math.h>    /* expf for sigmoid/GELU epilogues */
#include <stddef.h>  /* NULL */

#if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
#include <immintrin.h>
#if defined(__AVX__) || defined(__AVX512F__)
/*
 * Horizontal sum of the 8 single-precision lanes of an AVX register.
 * Classic log2 reduction: 8 -> 4 -> 2 -> 1 adds, no cross-lane shuffles
 * beyond the single 128-bit extract.
 */
static inline float hsum256_ps_fused(__m256 v)
{
    __m128 lo = _mm256_castps256_ps128(v);   /* lanes 0..3 */
    __m128 hi = _mm256_extractf128_ps(v, 1); /* lanes 4..7 */
    __m128 sum128 = _mm_add_ps(lo, hi);      /* 8 -> 4 partial sums */
    __m128 shuf = _mm_movehdup_ps(sum128);   /* duplicate odd lanes */
    __m128 sums = _mm_add_ps(sum128, shuf);  /* 4 -> 2 */
    shuf = _mm_movehl_ps(shuf, sums);        /* move high pair down */
    sums = _mm_add_ss(sums, shuf);           /* 2 -> 1 */
    return _mm_cvtss_f32(sums);
}
#endif
#if defined(__AVX512F__)
/*
 * Horizontal sum of the 16 single-precision lanes of an AVX-512 register.
 *
 * NOTE: the obvious _mm512_extractf32x8_ps requires AVX512DQ, which is NOT
 * implied by __AVX512F__ (e.g. Knights Landing has F but not DQ).  Extract
 * the upper 256 bits through the AVX512F-only 64x4 form and bitcast instead,
 * so a plain -mavx512f build still compiles.
 */
static inline float hsum512_ps_fused(__m512 v) {
    __m256 lo = _mm512_castps512_ps256(v);
    __m256 hi = _mm256_castpd_ps(
        _mm512_extractf64x4_pd(_mm512_castps_pd(v), 1));
    __m256 sum256 = _mm256_add_ps(lo, hi); /* 16 -> 8 partial sums */
    return hsum256_ps_fused(sum256);       /* 8 -> 1 */
}
#endif
75 float sx = 1.702f * x;
76 float sig = 1.0f / (1.0f + expf(-sx));
/*
 * C = ReLU(A * B^T + bias)
 *
 * A: M x K row-major, B: N x K row-major (so each C[i][j] is a dot product
 * of row i of A with row j of B), C: M x N.  bias has length N and may be
 * NULL, in which case no bias is added.
 * AVX path does 8-wide accumulation with a scalar remainder loop; the #else
 * branch is a plain scalar fallback with identical results up to FP
 * reassociation.
 */
void gemm_bias_relu_fused(const float *A, const float *B, const float *bias,
                          float *C, int M, int N, int K)
{
#if defined(__AVX__) || defined(__AVX512F__)
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            __m256 sum_vec = _mm256_setzero_ps();
            int k;
            for (k = 0; k <= K - 8; k += 8) {
                __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
                __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
                __m256 prod = _mm256_mul_ps(a_vec, b_vec);
                sum_vec = _mm256_add_ps(sum_vec, prod);
            }
            float sum = hsum256_ps_fused(sum_vec);
            /* scalar tail for K not divisible by 8 */
            for (; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            if (bias) sum += bias[j];
            C[i * N + j] = sum > 0.0f ? sum : 0.0f; /* fused ReLU */
        }
    }
#else
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            if (bias) sum += bias[j];
            C[i * N + j] = sum > 0.0f ? sum : 0.0f; /* fused ReLU */
        }
    }
#endif
}
/*
 * C = GELU(A * B^T + bias)
 *
 * Same layout contract as gemm_bias_relu_fused: A is M x K, B is N x K
 * (row-major), C is M x N, bias has length N and may be NULL.
 * The activation uses the fast sigmoid-form GELU approximation
 * (fast_gelu_scalar), applied once per output element after the reduction.
 */
void gemm_bias_gelu_fused(const float *A, const float *B, const float *bias,
                          float *C, int M, int N, int K)
{
#if defined(__AVX__) || defined(__AVX512F__)
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            __m256 sum_vec = _mm256_setzero_ps();
            int k;
            for (k = 0; k <= K - 8; k += 8) {
                __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
                __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
                __m256 prod = _mm256_mul_ps(a_vec, b_vec);
                sum_vec = _mm256_add_ps(sum_vec, prod);
            }
            float sum = hsum256_ps_fused(sum_vec);
            /* scalar tail for K not divisible by 8 */
            for (; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            if (bias) sum += bias[j];
            C[i * N + j] = fast_gelu_scalar(sum); /* fused GELU */
        }
    }
#else
    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[j * K + k];
            }
            if (bias) sum += bias[j];
            C[i * N + j] = fast_gelu_scalar(sum); /* fused GELU */
        }
    }
#endif
}
184 #pragma omp parallel for
185 for (
int i = 0; i < M; i++) {
186 for (
int j = 0; j < N; j++) {
187 __m256 sum_vec = _mm256_setzero_ps();
189 for (k = 0; k <= K - 8; k += 8) {
190 __m256 a_vec = _mm256_loadu_ps(&A[i * K + k]);
191 __m256 b_vec = _mm256_loadu_ps(&B[j * K + k]);
192 __m256 prod = _mm256_mul_ps(a_vec, b_vec);
193 sum_vec = _mm256_add_ps(sum_vec, prod);
197 sum += A[i * K + k] * B[j * K + k];
201 float sig = 1.0f / (1.0f + expf(-sum));
202 C[i * N + j] = sum * sig;
206 #pragma omp parallel for
207 for (
int i = 0; i < M; i++) {
208 for (
int j = 0; j < N; j++) {
210 for (
int k = 0; k < K; k++) {
211 sum += A[i * K + k] * B[j * K + k];
214 float sig = 1.0f / (1.0f + expf(-sum));
215 C[i * N + j] = sum * sig;
250 #pragma omp parallel for
251 for (
int i = 0; i < M; i++) {
252 const float *x_row = &x[i * K];
253 float *out_row = &output[i * N];
255 for (
int j = 0; j < N; j++) {
256 const float *w_gate_row = &W_gate[j * K];
257 const float *w_up_row = &W_up[j * K];
260 __m256 gate_vec = _mm256_setzero_ps();
261 __m256 up_vec = _mm256_setzero_ps();
264 for (k = 0; k <= K - 8; k += 8) {
265 __m256 x_vec = _mm256_loadu_ps(&x_row[k]);
266 __m256 wg_vec = _mm256_loadu_ps(&w_gate_row[k]);
267 __m256 wu_vec = _mm256_loadu_ps(&w_up_row[k]);
270 gate_vec = _mm256_add_ps(gate_vec, _mm256_mul_ps(x_vec, wg_vec));
272 up_vec = _mm256_add_ps(up_vec, _mm256_mul_ps(x_vec, wu_vec));
281 gate += x_row[k] * w_gate_row[k];
282 up += x_row[k] * w_up_row[k];
286 if (b_gate) gate += b_gate[j];
287 if (b_up) up += b_up[j];
290 float sig = 1.0f / (1.0f + expf(-gate));
291 out_row[j] = gate * sig * up;
296 #pragma omp parallel for
297 for (
int i = 0; i < M; i++) {
298 for (
int j = 0; j < N; j++) {
302 for (
int k = 0; k < K; k++) {
303 gate += x[i * K + k] * W_gate[j * K + k];
304 up += x[i * K + k] * W_up[j * K + k];
307 if (b_gate) gate += b_gate[j];
308 if (b_up) up += b_up[j];
311 float sig = 1.0f / (1.0f + expf(-gate));
312 output[i * N + j] = gate * sig * up;
/*
 * Public API prototypes for the fused GEMM + activation kernels defined above.
 * (Internal helpers fast_gelu_scalar / hsum256_ps_fused are static and are
 * intentionally not re-declared here; a bare __m256 prototype would also
 * break non-AVX translation units.)
 */
void gemm_bias_relu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K);
void gemm_bias_gelu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K);
void gemm_bias_silu_fused(const float *A, const float *B, const float *bias, float *C, int M, int N, int K);
void gemm_swiglu_fused(const float *x, const float *W_gate, const float *W_up, const float *b_gate, const float *b_up, float *output, int M, int N, int K);