#if defined(__AVX512F__)
#include <immintrin.h>
#endif

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"

/* bf16 x bf16 -> fp32 GEMM with a per-output-column fp32 bias; C is M x N. */
void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B,
                       const float *bias, float *C,
                       int M, int N, int K);

static float gelu_scalar(float x) {
    const float c = 0.7978845608f; /* sqrt(2/pi) */
    const float k = 0.044715f;
    const float x3 = x * x * x;
    return 0.5f * x * (1.0f + tanhf(c * (x + k * x3)));
}
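
/*
 * Hypothetical reference for the bf16 -> fp32 widening used throughout this
 * file: the 16 stored bits simply become the upper half of an IEEE-754
 * binary32 word. The name bf16_to_float_ref is illustrative; the file's own
 * bf16_to_float helper (declared in the summary at the end) may differ.
 */
static inline float bf16_to_float_ref(uint16_t v)
{
    union { uint32_t u; float f; } bits;
    bits.u = (uint32_t)v << 16;   /* bf16 payload occupies the high 16 bits */
    return bits.f;
}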

#if defined(__AVX512F__)
static inline __m512 gelu_avx512(__m512 x) {
    const __m512 c = _mm512_set1_ps(0.7978845608f); /* sqrt(2/pi) */
    const __m512 k = _mm512_set1_ps(0.044715f);
    const __m512 half = _mm512_set1_ps(0.5f);
    const __m512 one = _mm512_set1_ps(1.0f);

    __m512 x2 = _mm512_mul_ps(x, x);
    __m512 x3 = _mm512_mul_ps(x2, x);

    /* inner = sqrt(2/pi) * (x + 0.044715 * x^3) */
    __m512 inner = _mm512_fmadd_ps(k, x3, x);
    inner = _mm512_mul_ps(c, inner);

    /* Rational approximation: tanh(y) ~= y * (27 + y^2) / (27 + 9*y^2) */
    __m512 inner2 = _mm512_mul_ps(inner, inner);
    __m512 num = _mm512_add_ps(_mm512_set1_ps(27.0f), inner2);
    __m512 den = _mm512_fmadd_ps(_mm512_set1_ps(9.0f), inner2, _mm512_set1_ps(27.0f));
    __m512 tanh_approx = _mm512_mul_ps(inner, _mm512_div_ps(num, den));

    /* Clamp to [-1, 1]; the approximation overshoots for large |inner|. */
    tanh_approx = _mm512_min_ps(tanh_approx, one);
    tanh_approx = _mm512_max_ps(tanh_approx, _mm512_set1_ps(-1.0f));
    /* gelu(x) = 0.5 * x * (1 + tanh_approx) */
    __m512 result = _mm512_add_ps(one, tanh_approx);
    result = _mm512_mul_ps(half, _mm512_mul_ps(x, result));
    return result;
}
#endif /* __AVX512F__ */
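
/*
 * Hypothetical self-check (guarded by an assumed MLP_SELF_TEST macro): compares
 * the vectorized GELU against the scalar reference on a few sample points. The
 * 0.05f tolerance is a loose assumed bound for the rational tanh approximation,
 * not a measured figure.
 */
#if defined(__AVX512F__) && defined(MLP_SELF_TEST)
static int gelu_avx512_selfcheck(void)
{
    float in[16], out[16];
    for (int i = 0; i < 16; ++i)
        in[i] = -4.0f + 0.5f * (float)i;   /* sample points in [-4.0, 3.5] */
    _mm512_storeu_ps(out, gelu_avx512(_mm512_loadu_ps(in)));
    for (int i = 0; i < 16; ++i)
        if (fabsf(out[i] - gelu_scalar(in[i])) > 0.05f)
            return 0;                      /* vector and scalar paths disagree */
    return 1;
}
#endif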

void mlp_token_parallel_bf16(const uint16_t *input,
                             const uint16_t *W_fc1,
                             const uint16_t *b_fc1,
                             const uint16_t *W_fc2,
                             const uint16_t *b_fc2,
                             float *fc1_output,
                             float *output,
                             int T,
                             int aligned_dim,
                             int num_threads,
                             float *scratch_bias1_f,
                             float *scratch_bias2_f,
                             uint16_t *scratch_fc1_bf16)
{
    if (!input || !W_fc1 || !b_fc1 || !W_fc2 || !b_fc2 || !fc1_output || !output)
        return;
    if (!scratch_bias1_f || !scratch_bias2_f || !scratch_fc1_bf16)
        return;
    const int D = aligned_dim;
    const int fourD = 4 * D;
    /* Expand the bf16 biases to fp32 so the GEMM can add them directly. */
    for (int i = 0; i < fourD; ++i) {
        scratch_bias1_f[i] = bf16_to_float(b_fc1[i]);
    }
    for (int i = 0; i < D; ++i) {
        scratch_bias2_f[i] = bf16_to_float(b_fc2[i]);
    }

    /* FC1: input [T x D] times W_fc1 -> fc1_output [T x 4D]. */
    gemm_bf16_fp32out(input, W_fc1, scratch_bias1_f, fc1_output, T, fourD, D);
    /* Apply GELU in place to the fp32 FC1 activations, one token row per thread. */
#if defined(__AVX512F__)
#pragma omp parallel for
    for (int t = 0; t < T; ++t) {
        float *row = fc1_output + (size_t)t * fourD;
        int j = 0;
        for (; j <= fourD - 16; j += 16) {
            __m512 x = _mm512_loadu_ps(row + j);
            _mm512_storeu_ps(row + j, gelu_avx512(x));
        }
        for (; j < fourD; ++j) {
            row[j] = gelu_scalar(row[j]); /* scalar tail for the last < 16 elements */
        }
    }
#else
    for (int t = 0; t < T; ++t) {
        for (int j = 0; j < fourD; ++j) {
            fc1_output[(size_t)t * fourD + j] =
                gelu_scalar(fc1_output[(size_t)t * fourD + j]);
        }
    }
#endif
    /* Re-quantize the GELU output to bf16 (round to nearest even) for the FC2 GEMM. */
#if defined(__AVX512F__)
#pragma omp parallel for
    for (int t = 0; t < T; ++t) {
        float *src = fc1_output + (size_t)t * fourD;
        uint16_t *dst = scratch_fc1_bf16 + (size_t)t * fourD;
        int j = 0;
        for (; j <= fourD - 16; j += 16) {
            __m512 fp32 = _mm512_loadu_ps(src + j);
            __m512i as_int = _mm512_castps_si512(fp32);
            /* Round to nearest even: add 0x7FFF plus the LSB of the truncated result. */
            __m512i lsb = _mm512_srli_epi32(as_int, 16);
            lsb = _mm512_and_si512(lsb, _mm512_set1_epi32(1));
            __m512i rounding = _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), lsb);
            __m512i rounded = _mm512_add_epi32(as_int, rounding);
            __m512i shifted = _mm512_srli_epi32(rounded, 16);
            __m256i bf16 = _mm512_cvtepi32_epi16(shifted);
            _mm256_storeu_si256((__m256i *)(dst + j), bf16);
        }
        for (; j < fourD; ++j) {
            dst[j] = float_to_bf16(src[j]); /* scalar tail */
        }
    }
#else
    for (size_t i = 0; i < (size_t)T * fourD; ++i) {
        scratch_fc1_bf16[i] = float_to_bf16(fc1_output[i]);
    }
#endif
    /* FC2: fc1 activations [T x 4D] times W_fc2 -> output [T x D]. */
    gemm_bf16_fp32out(scratch_fc1_bf16, W_fc2, scratch_bias2_f, output, T, D, fourD);
}
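
/*
 * Hypothetical scalar companion to the vector narrowing loop above: the same
 * round-to-nearest-even conversion expressed on a single float. The name
 * float_to_bf16_ref is an illustrative assumption; this file's own
 * float_to_bf16 helper (declared in the summary below) may be implemented
 * differently.
 */
static inline uint16_t float_to_bf16_ref(float f)
{
    union { uint32_t u; float f; } bits;
    bits.f = f;
    uint32_t lsb = (bits.u >> 16) & 1u;                /* ties round to even */
    return (uint16_t)((bits.u + 0x7FFFu + lsb) >> 16); /* drop the low 16 bits */
}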

void mlp_token_parallel_bf16_fp32act(const uint16_t *input,
                                     const uint16_t *W_fc1,
                                     const uint16_t *b_fc1,
                                     const uint16_t *W_fc2,
                                     const uint16_t *b_fc2,
                                     float *fc1_output,
                                     float *output,
                                     int T,
                                     int aligned_dim,
                                     int num_threads,
                                     float *scratch_input_f,
                                     float *scratch_bias1_f,
                                     float *scratch_bias2_f,
                                     uint16_t *scratch_fc1_bf16)
{
    if (!input || !W_fc1 || !b_fc1 || !W_fc2 || !b_fc2 || !fc1_output || !output)
        return;
    if (!scratch_input_f || !scratch_bias1_f || !scratch_bias2_f || !scratch_fc1_bf16)
        return;
    const int D = aligned_dim;
    const int fourD = 4 * D;
    /* Apply GELU in place to the fp32 FC1 activations, one token row per thread. */
#if defined(__AVX512F__)
#pragma omp parallel for
    for (int t = 0; t < T; ++t) {
        float *row = fc1_output + (size_t)t * fourD;
        int j = 0;
        for (; j <= fourD - 16; j += 16) {
            __m512 x = _mm512_loadu_ps(row + j);
            _mm512_storeu_ps(row + j, gelu_avx512(x));
        }
        for (; j < fourD; ++j) {
            row[j] = gelu_scalar(row[j]); /* scalar tail for the last < 16 elements */
        }
    }
#endif
    /* Re-quantize the GELU output to bf16 for the FC2 GEMM (scalar path). */
    for (size_t i = 0; i < (size_t)T * fourD; ++i) {
        scratch_fc1_bf16[i] = float_to_bf16(fc1_output[i]);
    }
    /* FC2: fc1 activations [T x 4D] times W_fc2 -> output [T x D]. */
    gemm_bf16_fp32out(scratch_fc1_bf16, W_fc2, scratch_bias2_f, output, T, D, fourD);
}

#pragma GCC diagnostic pop
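
/*
 * Hypothetical scalar reference for gemm_bf16_fp32out, assuming row-major
 * A (M x K), B (K x N) and C (M x N) with a length-N bias added to every row.
 * The real kernel's data layout, blocking and vectorization are not shown in
 * this excerpt; this sketch only documents the intended arithmetic.
 */
static inline void gemm_bf16_fp32out_ref(const uint16_t *A, const uint16_t *B,
                                         const float *bias, float *C,
                                         int M, int N, int K)
{
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            float acc = bias[n];
            for (int k = 0; k < K; ++k)
                acc += bf16_to_float(A[(size_t)m * K + k]) *
                       bf16_to_float(B[(size_t)k * N + n]);
            C[(size_t)m * N + n] = acc;
        }
    }
}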

/* Summary of the functions in this file: */
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count);
static uint16_t float_to_bf16(float f);
static float bf16_to_float(uint16_t v);
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count);
static float gelu_scalar(float x);

void mlp_token_parallel_bf16(const uint16_t *input, const uint16_t *W_fc1,
                             const uint16_t *b_fc1, const uint16_t *W_fc2,
                             const uint16_t *b_fc2, float *fc1_output, float *output,
                             int T, int aligned_dim, int num_threads,
                             float *scratch_bias1_f, float *scratch_bias2_f,
                             uint16_t *scratch_fc1_bf16);

void mlp_token_parallel_bf16_fp32act(const uint16_t *input, const uint16_t *W_fc1,
                                     const uint16_t *b_fc1, const uint16_t *W_fc2,
                                     const uint16_t *b_fc2, float *fc1_output,
                                     float *output, int T, int aligned_dim,
                                     int num_threads, float *scratch_input_f,
                                     float *scratch_bias1_f, float *scratch_bias2_f,
                                     uint16_t *scratch_fc1_bf16);

void gemm_bf16_fp32out(const uint16_t *A, const uint16_t *B, const float *bias,
                       float *C, int M, int N, int K);
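
/*
 * Hypothetical usage sketch: shows how the scratch buffers are sized relative
 * to T and aligned_dim before calling mlp_token_parallel_bf16. The dimensions,
 * the zero-initialized weights and the MLP_USAGE_EXAMPLE guard are placeholder
 * assumptions for illustration only.
 */
#if defined(MLP_USAGE_EXAMPLE)
enum { EX_T = 2, EX_D = 64, EX_4D = 4 * EX_D };

static uint16_t ex_input[EX_T * EX_D];       /* bf16 activations, T x D           */
static uint16_t ex_W_fc1[EX_D * EX_4D];      /* bf16 FC1 weights, D -> 4D         */
static uint16_t ex_b_fc1[EX_4D];
static uint16_t ex_W_fc2[EX_4D * EX_D];      /* bf16 FC2 weights, 4D -> D         */
static uint16_t ex_b_fc2[EX_D];
static float    ex_fc1_output[EX_T * EX_4D]; /* fp32 hidden activations, T x 4D   */
static float    ex_output[EX_T * EX_D];      /* fp32 result, T x D                */
static float    ex_bias1[EX_4D];
static float    ex_bias2[EX_D];
static uint16_t ex_fc1_bf16[EX_T * EX_4D];   /* hidden activations re-cast to bf16 */

static void run_mlp_example(void)
{
    mlp_token_parallel_bf16(ex_input, ex_W_fc1, ex_b_fc1, ex_W_fc2, ex_b_fc2,
                            ex_fc1_output, ex_output, EX_T, EX_D,
                            /*num_threads=*/4,
                            ex_bias1, ex_bias2, ex_fc1_bf16);
}
#endif /* MLP_USAGE_EXAMPLE */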