#if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__)
#include <immintrin.h>
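/*
 * The bf16 kernels below lean on two vector helpers whose definitions are
 * not shown in this excerpt. A minimal sketch of what they are assumed to
 * do, exploiting the fact that bf16 is the high half of an fp32 bit
 * pattern (the "_sketch" names are illustrative; a production store would
 * round rather than truncate, or use AVX512-BF16's _mm512_cvtneps_pbh
 * where available):
 */
static inline __m512 bf16_loadu_cvt_fp32_sketch(const uint16_t *p)
{
    /* Load 16 bf16 values, zero-extend each to 32 bits, then shift them
     * into the sign/exponent/mantissa position of an fp32 lane. */
    __m256i raw = _mm256_loadu_si256((const __m256i *)p);
    __m512i wide = _mm512_cvtepu16_epi32(raw);
    return _mm512_castsi512_ps(_mm512_slli_epi32(wide, 16));
}

static inline void fp32_cvt_storeu_bf16_sketch(uint16_t *p, __m512 v)
{
    /* Keep the high 16 bits of each fp32 lane (truncation) and pack the
     * sixteen results back down to 16-bit elements. */
    __m512i bits = _mm512_castps_si512(v);
    __m256i out = _mm512_cvtepi32_epi16(_mm512_srli_epi32(bits, 16));
    _mm256_storeu_si256((__m256i *)p, out);
}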
/* From add_forward_bf16(): guard clause and AVX-512 main loop. */
if (!a || !b || !y || n == 0) {
    return;
}
size_t i = 0;
#if defined(__AVX512F__)
for (; i + 16 <= n; i += 16) {
    __m512 av = bf16_loadu_cvt_fp32(&a[i]);
    __m512 bv = bf16_loadu_cvt_fp32(&b[i]);
    __m512 yv = _mm512_add_ps(av, bv);
    fp32_cvt_storeu_bf16(&y[i], yv);
}
#endif
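/* Assumed scalar tail (elided from this excerpt), mirroring the visible
 * tail in add_forward_2d_bf16 below: */
for (; i < n; ++i) {
    y[i] = float_to_bf16(bf16_to_float(a[i]) + bf16_to_float(b[i]));
}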
/* From add_scaled_forward_bf16(): y[i] = a[i] + alpha * b[i]. */
if (!a || !b || !y || n == 0) {
    return;
}
size_t i = 0;
#if defined(__AVX512F__)
__m512 alpha_v = _mm512_set1_ps(alpha);
for (; i + 16 <= n; i += 16) {
    __m512 av = bf16_loadu_cvt_fp32(&a[i]);
    __m512 bv = bf16_loadu_cvt_fp32(&b[i]);
    /* Fused multiply-add: yv = bv * alpha_v + av. */
    __m512 yv = _mm512_fmadd_ps(bv, alpha_v, av);
    fp32_cvt_storeu_bf16(&y[i], yv);
}
#endif
/* From add_inplace_bf16(): a[i] += b[i], writing back into a. */
if (!a || !b || n == 0) {
    return;
}
size_t i = 0;
#if defined(__AVX512F__)
for (; i + 16 <= n; i += 16) {
    __m512 av = bf16_loadu_cvt_fp32(&a[i]);
    __m512 bv = bf16_loadu_cvt_fp32(&b[i]);
    __m512 yv = _mm512_add_ps(av, bv);
    fp32_cvt_storeu_bf16(&a[i], yv);  /* store into a, not a separate y */
}
#endif
/* From add_scaled_inplace_bf16(): a[i] += alpha * b[i] (axpy-style). */
if (!a || !b || n == 0) {
    return;
}
size_t i = 0;
#if defined(__AVX512F__)
__m512 alpha_v = _mm512_set1_ps(alpha);
for (; i + 16 <= n; i += 16) {
    __m512 av = bf16_loadu_cvt_fp32(&a[i]);
    __m512 bv = bf16_loadu_cvt_fp32(&b[i]);
    __m512 yv = _mm512_fmadd_ps(bv, alpha_v, av);
    fp32_cvt_storeu_bf16(&a[i], yv);
}
#endif
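/* Hypothetical usage: an SGD-style in-place update params[i] += -lr * g[i];
 * "params", "grads", "lr", and "n" are illustrative names, not part of
 * this file. */
add_scaled_inplace_bf16(params, grads, -lr, n);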
/* From add_backward_bf16(): for y = a + b, dL/da = dL/db = dL/dy, so the
 * backward pass just copies the incoming gradient into each output that
 * is present and not aliased to d_y. */
if (!d_y || n == 0) {
    return;
}
if (d_a && d_a != d_y) {
    size_t i = 0;
#if defined(__AVX512F__)
    /* Copy 64 bf16 elements (two 512-bit registers) per iteration. */
    for (; i + 64 <= n; i += 64) {
        __m512i v0 = _mm512_loadu_si512((const __m512i*)&d_y[i]);
        __m512i v1 = _mm512_loadu_si512((const __m512i*)&d_y[i + 32]);
        _mm512_storeu_si512((__m512i*)&d_a[i], v0);
        _mm512_storeu_si512((__m512i*)&d_a[i + 32], v1);
    }
#endif
}
if (d_b && d_b != d_y) {
    size_t i = 0;
#if defined(__AVX512F__)
    for (; i + 64 <= n; i += 64) {
        __m512i v0 = _mm512_loadu_si512((const __m512i*)&d_y[i]);
        __m512i v1 = _mm512_loadu_si512((const __m512i*)&d_y[i + 32]);
        _mm512_storeu_si512((__m512i*)&d_b[i], v0);
        _mm512_storeu_si512((__m512i*)&d_b[i + 32], v1);
    }
#endif
}
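/* Assumed tail for both gradient copies (elided from the excerpt): the
 * remainder below 64 elements can be finished with a plain copy, e.g.
 *     memcpy(&d_a[i], &d_y[i], (n - i) * sizeof(uint16_t));
 * which requires <string.h>. */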
/* From add_forward_2d_bf16(): row-wise add over a [tokens x dim] matrix
 * whose rows are stored with a padded stride of aligned_dim elements. */
if (!a || !b || !y || tokens <= 0 || dim <= 0) {
    return;
}
for (int t = 0; t < tokens; ++t) {
    const uint16_t *a_row = a + (size_t)t * aligned_dim;
    const uint16_t *b_row = b + (size_t)t * aligned_dim;
    uint16_t *y_row = y + (size_t)t * aligned_dim;
    int d = 0;
#if defined(__AVX512F__)
    for (; d + 16 <= dim; d += 16) {
        __m512 av = bf16_loadu_cvt_fp32(&a_row[d]);
        __m512 bv = bf16_loadu_cvt_fp32(&b_row[d]);
        __m512 yv = _mm512_add_ps(av, bv);
        fp32_cvt_storeu_bf16(&y_row[d], yv);
    }
#endif
    /* Scalar tail for the last dim % 16 columns. */
    for (; d < dim; ++d) {
        y_row[d] = float_to_bf16(bf16_to_float(a_row[d]) + bf16_to_float(b_row[d]));
    }
}
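/* Hypothetical call with illustrative sizes: 8 rows of 1000 values, each
 * row stored with a padded stride of 1024 elements so every row stays
 * 64-byte aligned. */
add_forward_2d_bf16(a, b, y, /*tokens=*/8, /*dim=*/1000, /*aligned_dim=*/1024);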
/* From add_forward_f32(): the fp32 variant needs no conversion helpers;
 * the AVX2 loop picks up 8-wide chunks left over by the AVX-512 loop. */
if (!a || !b || !y || n == 0) {
    return;
}
size_t i = 0;
#if defined(__AVX512F__)
for (; i + 16 <= n; i += 16) {
    __m512 av = _mm512_loadu_ps(&a[i]);
    __m512 bv = _mm512_loadu_ps(&b[i]);
    __m512 yv = _mm512_add_ps(av, bv);
    _mm512_storeu_ps(&y[i], yv);
}
#endif
#if defined(__AVX2__)
for (; i + 8 <= n; i += 8) {
    __m256 av = _mm256_loadu_ps(&a[i]);
    __m256 bv = _mm256_loadu_ps(&b[i]);
    __m256 yv = _mm256_add_ps(av, bv);
    _mm256_storeu_ps(&y[i], yv);
}
#endif
/* From add_inplace_f32(): same structure, accumulating into a. */
if (!a || !b || n == 0) {
    return;
}
size_t i = 0;
#if defined(__AVX512F__)
for (; i + 16 <= n; i += 16) {
    __m512 av = _mm512_loadu_ps(&a[i]);
    __m512 bv = _mm512_loadu_ps(&b[i]);
    __m512 yv = _mm512_add_ps(av, bv);
    _mm512_storeu_ps(&a[i], yv);
}
#endif
#if defined(__AVX2__)
for (; i + 8 <= n; i += 8) {
    __m256 av = _mm256_loadu_ps(&a[i]);
    __m256 bv = _mm256_loadu_ps(&b[i]);
    __m256 yv = _mm256_add_ps(av, bv);
    _mm256_storeu_ps(&a[i], yv);
}
#endif
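/* Assumed scalar tail (elided from the excerpt): finishes the last n % 8
 * elements after the vector loops. */
for (; i < n; ++i) {
    a[i] += b[i];
}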
/* Entry points and scalar helpers declared in this file: */
void add_forward_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, size_t n);
void add_scaled_forward_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, float alpha, size_t n);
void add_inplace_bf16(uint16_t *a, const uint16_t *b, size_t n);
void add_scaled_inplace_bf16(uint16_t *a, const uint16_t *b, float alpha, size_t n);
void add_backward_bf16(const uint16_t *d_y, uint16_t *d_a, uint16_t *d_b, size_t n);
void add_forward_2d_bf16(const uint16_t *a, const uint16_t *b, uint16_t *y, int tokens, int dim, int aligned_dim);
void add_forward_f32(const float *a, const float *b, float *y, size_t n);
void add_inplace_f32(float *a, const float *b, size_t n);
static uint16_t float_to_bf16(float f);
static float bf16_to_float(uint16_t v);
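/*
 * A common implementation sketch for the two scalar helpers declared
 * above, using round-to-nearest-even in the narrowing direction; the
 * file's actual definitions may differ (e.g. in NaN handling). Requires
 * <stdint.h> and <string.h>; the "_sketch" names are illustrative.
 */
static uint16_t float_to_bf16_sketch(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    /* Round to nearest even before dropping the low 16 mantissa bits. */
    bits += 0x7FFFu + ((bits >> 16) & 1u);
    return (uint16_t)(bits >> 16);
}

static float bf16_to_float_sketch(uint16_t v)
{
    /* Widening is exact: bf16 occupies the high half of an fp32. */
    uint32_t bits = (uint32_t)v << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}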