7 #if defined(__AVX512F__)
/* Move v into bits 31..16 of the 32-bit word tmp.u; bits 15..0 end up zero.
 * NOTE(review): presumably tmp is a uint32_t/float punning union and v is the
 * bf16 input -- the declarations are not visible in this excerpt; confirm. */
44 tmp.u = (uint32_t)v << 16;
/* lsb is bit 16 of tmp.u, i.e. the lowest bit that survives the final >> 16. */
98 uint32_t lsb = (tmp.u >> 16) & 1u;
/* Round to nearest, ties to even: adding 0x7FFF + lsb carries into bit 16
 * when the discarded low half is above 0x8000, or equals 0x8000 and lsb is 1.
 * NOTE(review): no NaN guard is visible in this excerpt; for a NaN bit
 * pattern this addition can carry into the exponent/sign bits -- confirm
 * whether the lines not shown handle that case. */
100 tmp.u += 0x7FFFu + lsb;
/* Return the upper 16 bits of the rounded word. */
102 return (uint16_t)(tmp.u >> 16);
137 #if defined(__AVX512F__)
/*
 * Reinterpret sixteen 16-bit values as sixteen fp32 values.
 *
 * Every 16-bit lane of bf16_vec is zero-extended to 32 bits and then shifted
 * left by 16, so the input bits occupy the upper half of each 32-bit lane and
 * the lower half is zero. The resulting integer vector is bit-cast (no value
 * conversion) to a vector of floats.
 */
static inline __m512 bf16x16_to_fp32(__m256i bf16_vec)
{
    return _mm512_castsi512_ps(
        _mm512_slli_epi32(_mm512_cvtepu16_epi32(bf16_vec), 16));
}
166 #if defined(__AVX512BF16__)
/*
 * Convert sixteen fp32 values to sixteen bf16 values using the
 * _mm512_cvtneps_pbh intrinsic, then hand the 256-bit result back as a plain
 * integer vector via a vector cast.
 */
static inline __m256i fp32x16_to_bf16(__m512 fp32_vec)
{
    return (__m256i)_mm512_cvtneps_pbh(fp32_vec);
}
/*
 * Convert sixteen fp32 values to sixteen bf16 values without AVX512-BF16
 * hardware support.
 *
 * Finite values and infinities are rounded to nearest, ties to even: the bias
 * 0x7FFF plus the lowest kept bit (bit 16) is added to the raw fp32 bits and
 * the upper 16 bits are kept.
 *
 * NaN lanes must not go through that addition: the bias can carry out of the
 * mantissa and turn a NaN into an infinity (0x7F800001 -> 0x7F80) or into a
 * signed zero (0x7FFFFFFF -> 0x8000). For NaN lanes the upper 16 bits are
 * taken as-is and the quiet bit (0x0040 in bf16) is forced, so the result is
 * always a NaN with the original sign.
 *
 * Returns the sixteen bf16 bit patterns packed into a 256-bit integer vector.
 */
static inline __m256i fp32x16_to_bf16(__m512 fp32_vec)
{
    const __m512i bits = _mm512_castps_si512(fp32_vec);
    const __m512i upper = _mm512_srli_epi32(bits, 16);

    /* Round-to-nearest-even bias: 0x7FFF + (bit 16 of the input). */
    const __m512i keep_lsb = _mm512_and_si512(upper, _mm512_set1_epi32(1));
    const __m512i bias = _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), keep_lsb);
    const __m512i rounded = _mm512_srli_epi32(_mm512_add_epi32(bits, bias), 16);

    /* A lane compares unordered with itself only when it holds a NaN. */
    const __mmask16 is_nan = _mm512_cmp_ps_mask(fp32_vec, fp32_vec, _CMP_UNORD_Q);

    /* Truncated NaN with the quiet bit set, so the mantissa can never be 0. */
    const __m512i quiet_nan = _mm512_or_si512(upper, _mm512_set1_epi32(0x0040));

    /* Take quiet_nan in NaN lanes, the rounded value everywhere else. */
    const __m512i result = _mm512_mask_blend_epi32(is_nan, rounded, quiet_nan);

    /* Narrow each 32-bit lane to its low 16 bits. */
    return _mm512_cvtepi32_epi16(result);
}
207 static inline __m512 bf16_loadu_cvt_fp32(
const uint16_t *ptr)
209 __m256i bf16_vec = _mm256_loadu_si256((
const __m256i *)ptr);
210 return bf16x16_to_fp32(bf16_vec);
214 static inline void fp32_cvt_storeu_bf16(uint16_t *ptr, __m512 fp32_vec)
216 __m256i bf16_vec = fp32x16_to_bf16(fp32_vec);
217 _mm256_storeu_si256((__m256i *)ptr, bf16_vec);
/* Vector path: compiled only when the compiler defines __AVX512F__. */
252 #if defined(__AVX512F__)
/* Main loop: consume 16 elements per iteration while at least 16 remain. */
254 for (; i + 16 <= count; i += 16) {
/* Load src[i..i+15] and widen them to sixteen fp32 values. */
255 __m512 fp32_vec = bf16_loadu_cvt_fp32(&src[i]);
/* Unaligned store of the sixteen floats to dst[i..i+15]. */
256 _mm512_storeu_ps(&dst[i], fp32_vec);
/* Tail loop: continues from the current i over the remaining elements.
 * NOTE(review): the loop body is not visible in this excerpt. */
258 for (; i < count; ++i) {
/* Element-by-element loop over all count entries.
 * NOTE(review): presumably the fallback used when __AVX512F__ is not defined;
 * the #else and the loop body are not visible here -- confirm. */
262 for (
size_t i = 0; i < count; ++i) {
/* Vector path: compiled only when the compiler defines __AVX512F__. */
273 #if defined(__AVX512F__)
/* Main loop: consume 16 elements per iteration while at least 16 remain. */
275 for (; i + 16 <= count; i += 16) {
/* Unaligned load of sixteen floats from src[i..i+15]. */
276 __m512 fp32_vec = _mm512_loadu_ps(&src[i]);
/* Convert the sixteen floats to bf16 and store them to dst[i..i+15]. */
277 fp32_cvt_storeu_bf16(&dst[i], fp32_vec);
/* Tail loop: continues from the current i over the remaining elements.
 * NOTE(review): the loop body is not visible in this excerpt. */
279 for (; i < count; ++i) {
/* Element-by-element loop over all count entries.
 * NOTE(review): presumably the fallback used when __AVX512F__ is not defined;
 * the #else and the loop body are not visible here -- confirm. */
283 for (
size_t i = 0; i < count; ++i) {
static void float_tensor_to_bf16(const float *src, uint16_t *dst, size_t count)
static uint16_t float_to_bf16(float f)
static float bf16_to_float(uint16_t v)
static void bf16_tensor_to_float(const uint16_t *src, float *dst, size_t count)